import os
import shutil
import random
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from collections import defaultdict, Counter
from torch_geometric.utils import degree
from torch_geometric.transforms import OneHotDegree
import seaborn as sns
from sklearn.manifold import TSNE
from torch_geometric.utils import to_undirected

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy and PyTorch (plus all CUDA devices when
    CUDA is available) and stores the seed in the ``PYTHONHASHSEED``
    environment variable.

    Args:
        seed: Value handed to each seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed_all(seed)

    os.environ['PYTHONHASHSEED'] = '{}'.format(seed)

# Runs when the module is imported/executed, so the generators are seeded
# even before any function below calls set_seed() explicitly.
set_seed(42)

class TopologyDataset(InMemoryDataset):
    """Topology classification dataset generator"""
    
    def __init__(self, root, mode='train', transform=None, pre_transform=None, pre_filter=None,
                 use_degree_features=True, min_nodes=None, max_nodes=None, samples_per_class=None):
        """Configure one split of the dataset and load its processed file.

        Args:
            root: Dataset root directory, forwarded to ``InMemoryDataset``.
            mode: Split name. It is embedded in the raw/processed file names
                and selects the default ``samples_per_class``.
            transform: Forwarded to ``InMemoryDataset``.
            pre_transform: Forwarded to ``InMemoryDataset``; also applied to
                every graph in ``process``.
            pre_filter: Forwarded to ``InMemoryDataset``; also used to filter
                the generated graphs in ``process``.
            use_degree_features: When true, ``process`` applies a one-hot
                degree transform to every graph.
            min_nodes: Inclusive lower bound of the node count drawn per
                graph; 20 when ``None``.
            max_nodes: Inclusive upper bound of the node count drawn per
                graph; 50 when ``None``.
            samples_per_class: Graphs generated per class; when ``None`` it
                is 500 for ``mode == 'train'`` and 50 for any other mode.
        """
        self.mode = mode
        self.use_degree_features = use_degree_features

        # Resolve every setting on its own. Previously a single missing
        # argument caused all three to be replaced by the defaults, so e.g.
        # passing only samples_per_class was silently ignored.
        self.min_nodes = 20 if min_nodes is None else min_nodes
        self.max_nodes = 50 if max_nodes is None else max_nodes
        if samples_per_class is None:
            samples_per_class = 500 if mode == 'train' else 50
        self.samples_per_class = samples_per_class

        # These must be set before super().__init__, because process() reads
        # them and may be invoked from the parent constructor.
        self._num_classes = 6
        self.class_names = [
            "Cyclic", "Geometric", "Community", "Hierarchical",
            "Bottleneck", "Multi-core"
        ]

        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may reject this pickled (data, slices)
        # tuple -- confirm against the installed torch / torch_geometric.
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
        return [f'dummy_{self.mode}.txt']
    
    @property
    def processed_file_names(self):
        return [f'topology_{self.mode}.pt']
    
    def download(self):
        """Write the placeholder raw file for this split.

        Nothing is fetched; the method only creates ``raw_dir`` (if needed)
        and a small text file named after ``raw_file_names``.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.raw_dir, exist_ok=True)
        # Take the name from raw_file_names so the two can never drift apart.
        placeholder_path = os.path.join(self.raw_dir, self.raw_file_names[0])
        with open(placeholder_path, 'w') as f:
            f.write('dummy')
    
    def process(self):
        """Generate all graphs of this split, featurise them and save them.

        For every class, ``samples_per_class`` graphs are generated with a
        node count drawn uniformly from ``[min_nodes, max_nodes]``. A
        disconnected graph is reduced to its largest connected component,
        nodes are relabelled to consecutive integers and the graph is turned
        into a ``Data`` object with constant node features and the class
        index as label. ``pre_filter``, the optional one-hot degree
        transform and ``pre_transform`` are applied in that order before the
        collated result is written to ``processed_paths[0]``.
        """
        # Position in this tuple is the class label (matches class_names).
        generators = (
            self.generate_cyclic,
            self.generate_random_geometric,
            self.generate_community,
            self.generate_hierarchical_hub,
            self.generate_bottleneck,
            self.generate_multi_core,
        )
        # Degree cap for the one-hot encoding. It is a constant rather than
        # the per-split maximum, so every split gets the same feature width.
        max_degree = 100

        data_list = []

        for class_idx in range(self._num_classes):
            generate = generators[class_idx]
            for _ in tqdm(range(self.samples_per_class), desc=f"Generating {self.mode} class {class_idx}"):
                n_nodes = random.randint(self.min_nodes, self.max_nodes)
                G = generate(n_nodes)

                # Generators do not all guarantee connectivity; fall back to
                # the largest component (this can shrink the graph).
                if not nx.is_connected(G):
                    largest_cc = max(nx.connected_components(G), key=len)
                    G = G.subgraph(largest_cc).copy()

                # Relabel nodes to 0..N-1 so they can index tensors.
                G = nx.convert_node_labels_to_integers(G)
                num_nodes = G.number_of_nodes()

                # reshape(-1, 2) and an explicit dtype keep a valid (2, E)
                # long tensor even when the graph has no edges.
                edge_index = torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2).t().contiguous()
                edge_index = to_undirected(edge_index, num_nodes=num_nodes)
                x = torch.ones(num_nodes, 1)  # Constant initial node features
                y = torch.tensor([class_idx], dtype=torch.long)

                data_list.append(Data(x=x, edge_index=edge_index, y=y))

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.use_degree_features:
            degree_transform = OneHotDegree(max_degree)
            data_list = [degree_transform(data) for data in data_list]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def generate_argg_topology(self, n_nodes, r_inner, r_outer, r_connect):
        """Build an annular random geometric graph.

        Nodes ``0 .. n_nodes-1`` receive random positions in the ring between
        ``r_inner`` and ``r_outer``; two nodes are joined when their
        Euclidean distance is at most ``r_connect``. The positions are only
        used to decide the edges and are not stored on the graph.

        Args:
            n_nodes: Number of nodes to create.
            r_inner: Inner radius of the ring.
            r_outer: Outer radius of the ring.
            r_connect: Maximum distance at which two nodes are connected.

        Returns:
            The resulting ``networkx.Graph``.
        """
        graph = nx.Graph()
        graph.add_nodes_from(range(n_nodes))

        # Draw one (angle, radius) pair per node; the angle is drawn first.
        coords = []
        for _ in range(n_nodes):
            angle = 2 * np.pi * random.random()
            frac = random.random()
            # Interpolating between the squared radii makes the points
            # uniform over the ring's area rather than over its radius.
            radius = np.sqrt(r_inner**2 + frac * (r_outer**2 - r_inner**2))
            coords.append((radius * np.cos(angle), radius * np.sin(angle)))

        # Join every unordered pair that lies close enough together.
        for a in range(n_nodes):
            ax, ay = coords[a]
            for b in range(a + 1, n_nodes):
                bx, by = coords[b]
                separation = np.sqrt((ax - bx)**2 + (ay - by)**2)
                if separation <= r_connect:
                    graph.add_edge(a, b)

        return graph
    
    def generate_cyclic(self, n_nodes):
        """Generate a ring-shaped graph from an annular random geometric graph.

        Up to ten parameter draws are tried. The first graph whose average
        degree lies in ``[3.0, 16.0]`` is returned; if none qualifies, the
        one whose average degree is closest to either bound is returned.
        Disconnected candidates are patched by linking one random node of
        every smaller component to a random node of the largest component.
        The parameters used are stored in ``G.graph['params']``.

        Args:
            n_nodes: Number of nodes of the graph.

        Returns:
            The selected ``networkx.Graph``.
        """
        degree_lo, degree_hi = 3.0, 16.0

        # One row per group of attempts: (r_inner range, ring-width range,
        # r_connect range). Attempts 0-2 use row 0, 3-5 row 1, 6-9 row 2.
        tiers = (
            ((0.9, 1.2), (0.1, 0.3), (0.5, 0.65)),
            ((0.8, 1.1), (0.1, 0.25), (0.6, 0.75)),
            ((0.7, 1.0), (0.05, 0.2), (0.65, 0.8)),
        )

        chosen_graph = None
        chosen_params = {}
        smallest_gap = float('inf')

        for attempt in range(10):
            inner_rng, width_rng, connect_rng = tiers[min(attempt // 3, 2)]
            r_inner = random.uniform(*inner_rng)
            r_outer = r_inner + random.uniform(*width_rng)
            r_connect = random.uniform(*connect_rng)

            candidate = self.generate_argg_topology(
                n_nodes=n_nodes,
                r_inner=r_inner,
                r_outer=r_outer,
                r_connect=r_connect
            )

            # Stitch stray components onto the largest one.
            if not nx.is_connected(candidate):
                parts = list(nx.connected_components(candidate))
                main_part = max(parts, key=len)
                for part in parts:
                    if part == main_part:
                        continue
                    src = random.choice(list(part))
                    dst = random.choice(list(main_part))
                    candidate.add_edge(src, dst)

            mean_degree = 2 * candidate.number_of_edges() / n_nodes
            params = {'r_inner': r_inner, 'r_outer': r_outer,
                      'r_connect': r_connect, 'avg_degree': mean_degree}

            # Accept immediately once the average degree is in range.
            if degree_lo <= mean_degree <= degree_hi:
                chosen_graph, chosen_params = candidate, params
                break

            # Otherwise remember the candidate nearest to either bound.
            gap = min(abs(mean_degree - degree_lo), abs(mean_degree - degree_hi))
            if gap < smallest_gap:
                smallest_gap = gap
                chosen_graph, chosen_params = candidate, params

        # Keep the generating parameters alongside the graph.
        chosen_graph.graph['params'] = chosen_params
        return chosen_graph
    
    def generate_random_geometric(self, n_nodes):
        """Generate a connected random geometric graph.

        The graph comes from ``networkx.random_geometric_graph`` with a
        connection radius drawn from ``[0.15, 0.25]``. If it is not
        connected, one random node of every smaller component is linked to a
        random node of the largest component.

        Args:
            n_nodes: Number of nodes of the graph.

        Returns:
            The resulting ``networkx.Graph``.
        """
        threshold = random.uniform(0.15, 0.25)
        graph = nx.random_geometric_graph(n_nodes, threshold)

        if nx.is_connected(graph):
            return graph

        # Attach each stray component to the largest one with a single edge.
        parts = list(nx.connected_components(graph))
        hub = max(parts, key=len)
        for part in parts:
            if part != hub:
                src = random.choice(list(part))
                dst = random.choice(list(hub))
                graph.add_edge(src, dst)

        return graph
    
    def generate_community(self, n_nodes):
        """Generate a graph made of densely connected communities.

        Nodes are split into consecutive communities (the last one absorbs
        the division remainder). Pairs inside a community are linked with a
        per-community probability drawn from ``[0.6, 0.8]``; pairs from
        different communities are linked with one probability drawn from
        ``[0.01, 0.05]``. If the result is disconnected, one random edge is
        added between each pair of consecutive communities. That repair does
        not reattach a node that is isolated inside its own community, so
        the returned graph may still be disconnected.

        Args:
            n_nodes: Number of nodes of the graph; must be at least 3.

        Returns:
            The resulting ``networkx.Graph``.

        Raises:
            ValueError: If ``n_nodes`` is smaller than 3, because every one
                of the (at least three) communities needs a node.
        """
        if n_nodes < 3:
            raise ValueError(f"generate_community needs at least 3 nodes, got {n_nodes}")

        G = nx.Graph()

        # 3-5 communities. The max(3, ...) clamp keeps the randint range
        # valid for n_nodes < 15, where it used to raise an opaque
        # ValueError; for n_nodes >= 15 the range is unchanged.
        num_communities = random.randint(3, max(3, min(5, n_nodes // 5)))
        nodes_per_comm = n_nodes // num_communities

        communities = []
        node_counter = 0

        for i in range(num_communities):
            # The last community takes whatever nodes are left.
            if i == num_communities - 1:
                comm_size = n_nodes - node_counter
            else:
                comm_size = nodes_per_comm

            community = list(range(node_counter, node_counter + comm_size))
            communities.append(community)
            node_counter += comm_size

            # Dense links inside the community (each unordered pair once).
            p_within = random.uniform(0.6, 0.8)
            for idx, u in enumerate(community):
                G.add_node(u)
                for v in community[idx + 1:]:
                    if random.random() < p_within:
                        G.add_edge(u, v)

        # Sparse links between different communities.
        p_between = random.uniform(0.01, 0.05)
        for i, comm_a in enumerate(communities):
            for comm_b in communities[i + 1:]:
                for u in comm_a:
                    for v in comm_b:
                        if random.random() < p_between:
                            G.add_edge(u, v)

        # Best-effort repair: chain consecutive communities together.
        if not nx.is_connected(G):
            for i in range(num_communities - 1):
                u = random.choice(communities[i])
                v = random.choice(communities[i + 1])
                G.add_edge(u, v)

        return G

    def generate_hierarchical_hub(self, n_nodes):
        """Generate a layered graph in which each level attaches to the one above.

        The number of levels is ``n_nodes // 10`` clamped to ``[2, 4]``.
        Nodes are assigned to levels in index order. Every level except the
        last gets random links among its own nodes, and every node of the
        next level is linked to between one and three random nodes of the
        level above it. The last level has no links among its own nodes.

        Args:
            n_nodes: Number of nodes of the graph.

        Returns:
            The resulting ``networkx.Graph``.
        """
        graph = nx.Graph()
        graph.add_nodes_from(range(n_nodes))

        depth = min(4, max(2, n_nodes // 10))
        shrink = 0.4  # Fraction of the remaining share taken by a non-final upper level

        # Shares of the upper levels: each takes `shrink` of what is left,
        # except the last upper level, which takes everything that is left.
        upper_ratios = []
        leftover = 1.0
        last_upper = depth - 2
        for lvl in range(depth - 1):
            if lvl < last_upper:
                share = leftover * shrink
            else:
                share = leftover
            upper_ratios.append(share)
            leftover -= share

        # Give the bottom level half of the nodes by halving the upper
        # shares. As written, the last upper level always consumes the whole
        # remainder, so this branch is taken on every call.
        if leftover < 0.5:
            upper_ratios = [share * 0.5 for share in upper_ratios]
            leftover = 0.5

        ratios = upper_ratios + [leftover]

        # Turn shares into node counts (at least one per level); the bottom
        # level absorbs the rounding difference so the total is n_nodes.
        sizes = [max(1, int(ratio * n_nodes)) for ratio in ratios]
        sizes[-1] += n_nodes - sum(sizes)

        # Consecutive index ranges, one per level.
        levels = []
        offset = 0
        for size in sizes:
            levels.append(list(range(offset, offset + size)))
            offset += size

        for lvl in range(depth - 1):
            parents, children = levels[lvl], levels[lvl + 1]

            # Links inside the upper level; density decreases with lvl.
            density = 0.7 * (depth - lvl) / depth
            for i, a in enumerate(parents):
                for b in parents[i + 1:]:
                    if random.random() < density:
                        graph.add_edge(a, b)

            # Each child attaches to 1-3 parents, limited by how many exist.
            for child in children:
                fanout = min(random.randint(1, 3), len(parents))
                for parent in random.sample(parents, fanout):
                    graph.add_edge(child, parent)

        return graph
    
    def generate_bottleneck(self, n_nodes):
        """Generate communities that are joined only through a few bridge edges.

        Nodes are split into consecutive communities (the last one absorbs
        the division remainder). Pairs inside a community are linked with a
        per-community probability drawn from ``[0.4, 0.6]``. Each pair of
        consecutive communities is then joined by one to three edges between
        randomly chosen, distinct nodes. Connectivity inside a community is
        not enforced here.

        Args:
            n_nodes: Number of nodes of the graph; must be at least 2.

        Returns:
            The resulting ``networkx.Graph``.

        Raises:
            ValueError: If ``n_nodes`` is smaller than 2, because both of
                the (at least two) communities need a node.
        """
        if n_nodes < 2:
            raise ValueError(f"generate_bottleneck needs at least 2 nodes, got {n_nodes}")

        G = nx.Graph()

        # 2-4 communities. The max(2, ...) clamp keeps the randint range
        # valid for n_nodes < 14, where it used to raise an opaque
        # ValueError; for n_nodes >= 14 the range is unchanged.
        num_communities = random.randint(2, max(2, min(4, n_nodes // 7)))
        nodes_per_comm = n_nodes // num_communities

        communities = []
        start_idx = 0

        for i in range(num_communities):
            # The last community takes whatever nodes are left.
            if i == num_communities - 1:
                comm_size = n_nodes - start_idx
            else:
                comm_size = nodes_per_comm

            community = list(range(start_idx, start_idx + comm_size))
            communities.append(community)
            start_idx += comm_size

            # Links inside the community (each unordered pair once).
            p_within = random.uniform(0.4, 0.6)
            for idx, u in enumerate(community):
                G.add_node(u)
                for v in community[idx + 1:]:
                    if random.random() < p_within:
                        G.add_edge(u, v)

        # Join consecutive communities through 1-3 bottleneck edges.
        for left, right in zip(communities, communities[1:]):
            bottleneck_count = random.randint(1, min(3, len(left), len(right)))

            bottlenecks1 = random.sample(left, bottleneck_count)
            bottlenecks2 = random.sample(right, bottleneck_count)

            for u, v in zip(bottlenecks1, bottlenecks2):
                G.add_edge(u, v)

        return G
    
    def generate_multi_core(self, n_nodes):
        """Generate several dense cores with periphery nodes hanging off them.

        Two or three cores are built from the lowest node indices; together
        they are budgeted 50-60% of the nodes, with at least four nodes in
        every core but the last. Pairs inside a core are linked with a
        per-core probability drawn from ``[0.6, 0.8]``, every pair of cores
        is bridged by one or two edges, and each remaining (periphery) node
        is linked to one or two nodes of a single randomly chosen core.

        Args:
            n_nodes: Number of nodes of the graph.

        Returns:
            The resulting ``networkx.Graph``.

        Raises:
            ValueError: If ``n_nodes`` is too small to leave any node for
                the last core.
        """
        G = nx.Graph()

        n_cores = random.randint(2, 3)

        # Total core budget and the size of every core except the last.
        core_ratio = random.uniform(0.5, 0.6)
        core_total_size = int(n_nodes * core_ratio)
        core_size = max(4, core_total_size // n_cores)

        cores = []
        node_counter = 0

        for i in range(n_cores):
            if i == n_cores - 1:
                # The last core gets the rest of the budget, but always
                # leaves at least five nodes for the periphery.
                core = list(range(node_counter, min(core_total_size, n_nodes - 5)))
                if not core:
                    # Fail with a clear message instead of the opaque
                    # ValueError random.sample would raise further down.
                    raise ValueError(
                        f"n_nodes={n_nodes} is too small for {n_cores} cores "
                        f"of {core_size} nodes"
                    )
            else:
                core = list(range(node_counter, node_counter + core_size))

            cores.append(core)
            node_counter += len(core)

            # Dense links inside the core (each unordered pair once).
            p_within = random.uniform(0.6, 0.8)
            for idx, u in enumerate(core):
                G.add_node(u)
                for v in core[idx + 1:]:
                    if random.random() < p_within:
                        G.add_edge(u, v)

        # Sparse bridges between every pair of cores.
        for i, core_a in enumerate(cores):
            for core_b in cores[i + 1:]:
                # Never ask for more bridge endpoints than a core has nodes;
                # a one-node core used to fail whenever 2 was drawn.
                bridge_count = random.randint(1, min(2, len(core_a), len(core_b)))
                bridges1 = random.sample(core_a, bridge_count)
                bridges2 = random.sample(core_b, bridge_count)

                for u, v in zip(bridges1, bridges2):
                    G.add_edge(u, v)

        # Everything after the cores is periphery.
        periphery = list(range(node_counter, n_nodes))

        for p in periphery:
            G.add_node(p)

            # Each periphery node belongs to exactly one core.
            selected_core = random.choice(cores)

            # Same clamp as for the bridges: at most len(core) connections.
            num_connections = random.randint(1, min(2, len(selected_core)))
            for connection in random.sample(selected_core, num_connections):
                G.add_edge(p, connection)

        return G

def generate_topology_dataset(root_dir, seed=42):
    """Create (or load) the train, validation and three test splits.

    The test splits differ only in their node-count range: ``ID`` matches
    the training range, ``Near-OOD`` and ``Far-OOD`` use larger graphs.
    Sizes and class distributions of all splits are printed.

    Args:
        root_dir: Directory under which one sub-directory per split is used.
        seed: Seed passed to ``set_seed`` before anything is generated.

    Returns:
        A tuple ``(train_dataset, val_dataset, test_datasets)`` where
        ``test_datasets`` maps the difficulty name to its dataset.
    """
    set_seed(seed)

    os.makedirs(root_dir, exist_ok=True)

    # Node-count range and size of each test split.
    difficulty_configs = {
        'ID': {
            'min_nodes': 20,
            'max_nodes': 50,  # Same distribution as training set
            'samples_per_class': 50
        },
        'Near-OOD': {
            'min_nodes': 40,
            'max_nodes': 100,  # Medium distribution shift
            'samples_per_class': 50
        },
        'Far-OOD': {
            'min_nodes': 60,
            'max_nodes': 150,  # Significant distribution shift
            'samples_per_class': 50
        }
    }

    train_dataset = TopologyDataset(os.path.join(root_dir, 'train'), mode='train')
    val_dataset = TopologyDataset(os.path.join(root_dir, 'val'), mode='val')

    test_datasets = {}

    for difficulty, config in difficulty_configs.items():
        test_dir = os.path.join(root_dir, f'test_{difficulty}')

        # The existence check only selects the progress message.
        processed_file = os.path.join(test_dir, 'processed', 'topology_test.pt')
        if os.path.exists(processed_file):
            print(f"Found existing {difficulty} test set, loading directly...")
        else:
            print(f"Generating {difficulty} test set...")

        # Always pass the config, also when the split is loaded from disk.
        # The cached branch used to omit it, which left min_nodes/max_nodes/
        # samples_per_class on the returned object at the 20-50 defaults
        # regardless of the split.
        test_dataset = TopologyDataset(
            test_dir,
            mode='test',
            use_degree_features=True,
            **config
        )

        test_datasets[difficulty] = test_dataset

        print(f"{difficulty} test set: {len(test_dataset)} samples, node range: {config['min_nodes']}-{config['max_nodes']}")

    print(f"Training set: {len(train_dataset)} samples, node range: {train_dataset.min_nodes}-{train_dataset.max_nodes}")
    print(f"Validation set: {len(val_dataset)} samples, node range: {val_dataset.min_nodes}-{val_dataset.max_nodes}")

    # Class distribution of every split.
    named_datasets = [('Training set', train_dataset), ('Validation set', val_dataset)]
    named_datasets += [(f'{d} test set', ds) for d, ds in test_datasets.items()]
    for name, dataset in named_datasets:
        class_counts = Counter(data.y.item() for data in dataset)
        print(f"\n{name} class distribution:")
        for class_idx, count in sorted(class_counts.items()):
            print(f"Class {dataset.class_names[class_idx]}: {count}")

    return train_dataset, val_dataset, test_datasets

def generate_topology_image_dataset(root_dir, layout="spring", image_size=224, seed=42):
    """Render every graph of every split to a PNG and index the images in a CSV.

    Images are written to ``<root_dir>/images`` and listed in
    ``<root_dir>/dataset.csv``. A split is skipped when the CSV already
    lists exactly as many images as the split has graphs and all those files
    exist; otherwise the whole split is rendered again and its previous CSV
    rows are replaced.

    Args:
        root_dir: Directory that holds the split sub-directories.
        layout: Name of a networkx layout; ``nx.<layout>_layout`` is used.
        image_size: Target edge length in pixels; the save DPI is derived
            from it for an 8x8 inch figure.
        seed: Seed passed to ``set_seed`` and used as base for the per-image
            spring-layout seed.

    Returns:
        A tuple ``(image_dir, dataset_csv)`` with the two output paths.
    """
    set_seed(seed)

    image_dir = os.path.join(root_dir, "images")
    os.makedirs(image_dir, exist_ok=True)

    dataset_csv = os.path.join(root_dir, "dataset.csv")

    # Load the PyG datasets; train and val first, then the test splits.
    all_datasets = {
        'train': TopologyDataset(os.path.join(root_dir, 'train'), mode='train'),
        'val': TopologyDataset(os.path.join(root_dir, 'val'), mode='val'),
    }
    for difficulty in ('ID', 'Near-OOD', 'Far-OOD'):
        test_dir = os.path.join(root_dir, f'test_{difficulty}')
        if os.path.exists(test_dir):
            all_datasets[f'test_{difficulty}'] = TopologyDataset(test_dir, mode='test')
        else:
            print(f"Warning: {difficulty} test set not found, please ensure test sets are generated")

    # Existing index, or an empty one with the same columns as the new rows.
    if os.path.exists(dataset_csv):
        print("Found existing dataset information.")
        df = pd.read_csv(dataset_csv)
    else:
        print("Dataset information not found, will create new dataset information.")
        df = pd.DataFrame(columns=["image_path", "label", "class_name", "split", "graph_idx"])

    new_rows = []
    rows_dropped = False
    layout_func = getattr(nx, f"{layout}_layout")
    fig_inches = 8
    dpi = image_size // fig_inches

    print("Generating images...")
    # Continue numbering after the rows that are already indexed.
    global_idx = len(df)

    for split, dataset in all_datasets.items():
        split_images = df[df['split'] == split]
        if not split_images.empty and len(split_images) == len(dataset):
            all_exist = all(os.path.exists(os.path.join(image_dir, path)) for path in split_images['image_path'])
            if all_exist:
                print(f"Found all images for {split} split, skipping generation.")
                continue

        # The split is rendered again from scratch, so its old rows must go;
        # keeping them used to leave duplicate entries in the CSV.
        if not split_images.empty:
            df = df[df['split'] != split]
            rows_dropped = True

        print(f"Generating images for {split} split...")

        for idx in tqdm(range(len(dataset)), desc=f"Generating {split} images"):
            data = dataset[idx]
            G = nx.Graph()
            G.add_nodes_from(range(data.num_nodes))
            G.add_edges_from(data.edge_index.t().tolist())

            # A per-image seed keeps spring layouts reproducible.
            if layout == "spring":
                pos = layout_func(G, seed=seed + global_idx, k=0.3)
            else:
                pos = layout_func(G)

            img_path = f"{split}_graph_{global_idx:04d}.png"
            full_path = os.path.join(image_dir, img_path)

            # Light nodes and white edges on a black background.
            fig = plt.figure(figsize=(fig_inches, fig_inches), facecolor='black')
            try:
                nx.draw_networkx_nodes(G, pos, node_size=50, node_color='skyblue',
                                       edgecolors='white', linewidths=0.8, alpha=0.9)
                nx.draw_networkx_edges(G, pos, width=1.5, alpha=0.8, edge_color='white')

                plt.axis('off')
                plt.tight_layout()

                # NOTE(review): bbox_inches='tight' crops the canvas, so the
                # saved file may not be exactly image_size pixels wide --
                # confirm that consumers resize the images.
                plt.savefig(full_path, dpi=dpi, facecolor='black', bbox_inches='tight')
            finally:
                # Always release the figure, even if drawing or saving fails.
                plt.close(fig)

            label = data.y.item()
            new_rows.append({
                "image_path": img_path,
                "label":      label,
                "class_name": dataset.class_names[label],
                "split":      split,
                "graph_idx":  idx          # local index in that split
            })

            global_idx += 1

    # Persist the index whenever rows were added or removed.
    if new_rows:
        new_df = pd.DataFrame(new_rows)
        df = new_df if df.empty else pd.concat([df, new_df], ignore_index=True)
    if new_rows or rows_dropped:
        df.to_csv(dataset_csv, index=False)

    print(f"Dataset image generation complete! Images saved in {image_dir}")
    print(f"Dataset information saved in {dataset_csv}")

    split_counts = df['split'].value_counts()
    print("\nDataset split distribution:")
    for split, count in split_counts.items():
        print(f"{split}: {count} samples")

    return image_dir, dataset_csv

# Main function entry
if __name__ == "__main__":
    # All splits and the rendered images live below this directory
    root_dir = "./topology_dataset"

    # Step 1: build (or load) the graph datasets; the third item returned
    # is a dict of test sets keyed by difficulty
    train_dataset, val_dataset, test_dataset = generate_topology_dataset(root_dir)

    # Step 2: render every graph to an image and write the CSV index
    image_dir, dataset_csv = generate_topology_image_dataset(
        root_dir, layout="spring", image_size=224, seed=42
    )