import torch
import numpy as np
import networkx as nx
from torch_geometric.data import Data, DataLoader
import random
from tqdm import tqdm

# Set random seeds for reproducibility.
# These calls run at import time. torch, NumPy (np.random) and the stdlib
# `random` module are each seeded separately with the same fixed value.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Global parameters
# Target device for tensors created in this module: CUDA when torch reports
# it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def sample_critical_pairs(G, num_pairs=10, min_distance=2):
    """Sample distinct, connected (source, target) node pairs from a graph.

    Candidate pairs are drawn with the stdlib ``random`` module. A candidate
    is accepted when a path from source to target exists and the unweighted
    shortest path between them has at least ``min_distance`` edges.

    Args:
        G: NetworkX graph to sample from.
        num_pairs: Number of pairs requested.
        min_distance: Minimum number of edges (hops) on the shortest path
            from source to target for a pair to be accepted.

    Returns:
        List of ``(source, target)`` tuples in acceptance order. The list may
        be shorter than ``num_pairs`` (a warning is printed) when the budget
        of ``num_pairs * 10`` draws is exhausted or when the graph has fewer
        than two nodes.
    """
    nodes = list(G.nodes())
    pairs = []
    accepted = set()  # O(1) duplicate check; `pairs` keeps acceptance order

    # random.sample needs at least two nodes to draw a distinct pair from.
    if len(nodes) >= 2:
        max_attempts = num_pairs * 10  # Avoid infinite loop
        attempts = 0

        while len(pairs) < num_pairs and attempts < max_attempts:
            attempts += 1
            s, t = random.sample(nodes, 2)

            # Skip if already sampled
            if (s, t) in accepted:
                continue

            try:
                path = nx.shortest_path(G, s, t)
            except nx.NetworkXNoPath:
                continue

            # `path` lists nodes, so the hop count is one less than its
            # length. Comparing len(path) directly would accept adjacent
            # nodes even with min_distance=2.
            if len(path) - 1 >= min_distance:
                accepted.add((s, t))
                pairs.append((s, t))

    if len(pairs) < num_pairs:
        print(f"Warning: Could only find {len(pairs)} valid pairs out of {num_pairs} requested")

    return pairs

def prepare_graph_data(G, node_features=None, edge_weights=None):
    """
    Convert a NetworkX graph to PyTorch Geometric Data object

    Args:
        G: NetworkX graph (directed or undirected).
        node_features: Optional mapping ``node -> sequence of 3 numbers``.
            When omitted, features are derived from ``G``:
            ``[in_degree, out_degree, degree_centrality]`` for directed
            graphs and ``[degree, degree_centrality, 0]`` for undirected ones.
        edge_weights: Optional dict ``(u, v) -> weight``. Edges missing from
            the dict get weight 1.0; for undirected graphs the reversed key
            ``(v, u)`` is tried before falling back. When omitted, or when
            the value is not a dict, every edge gets weight 1.0.

    Returns:
        Tuple ``(data, node_mapping)``:
            data: ``Data`` with ``x`` of shape (num_nodes, 3),
                ``edge_index`` of shape (2, num_edges) and ``edge_attr`` of
                shape (num_edges, 1). Each edge appears once, in the
                orientation ``G.edges()`` yields it.
            node_mapping: dict mapping each node to its row index, following
                the iteration order of ``G.nodes()``.
    """
    directed = nx.is_directed(G)
    num_nodes = G.number_of_nodes()

    # Default node features (degree centrality if not provided)
    if node_features is None:
        centrality = dict(nx.degree_centrality(G))
        if directed:
            in_degree = dict(G.in_degree())
            out_degree = dict(G.out_degree())
            node_features = {
                node: [in_degree.get(node, 0), out_degree.get(node, 0), centrality.get(node, 0)]
                for node in G.nodes()
            }
        else:
            degree = dict(G.degree())
            node_features = {
                node: [degree.get(node, 0), centrality.get(node, 0), 0]
                for node in G.nodes()
            }

    # An empty lookup makes every edge fall through to the 1.0 default, which
    # covers both "not provided" and "not a dict".
    weight_lookup = edge_weights if isinstance(edge_weights, dict) else {}

    # Node mapping to consecutive integers
    node_mapping = {node: i for i, node in enumerate(G.nodes())}

    # Node feature matrix, one row per node
    x = torch.zeros((num_nodes, 3), dtype=torch.float)
    for node, idx in node_mapping.items():
        x[idx] = torch.tensor(node_features[node], dtype=torch.float)

    # Edge endpoints (as row indices) and per-edge weight
    edge_pairs = []
    edge_values = []
    for u, v in G.edges():
        edge_pairs.append([node_mapping[u], node_mapping[v]])

        key = (u, v)
        # Undirected edges have no canonical orientation, so a caller may
        # have keyed the weight by (v, u).
        if key not in weight_lookup and not directed:
            key = (v, u)
        edge_values.append([weight_lookup.get(key, 1.0)])

    if edge_pairs:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_values, dtype=torch.float)
    else:
        # torch.tensor([]) would be 1-D; keep the (2, E) / (E, 1) layout
        # even when the graph has no edges.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float)

    # Create Data object
    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr
    )

    return data, node_mapping




def encode_graph_condition(G, K, T, node_embedding_dim=64, max_pairs=20):
    """
    Encode the graph structure, critical pairs, and threshold as condition vector

    The vector is laid out as 6 graph-level features, then 6 features for
    each of ``max_pairs`` pair slots (unused slots are zero-padded), then 2
    threshold features, for a total length of ``6 + 6 * max_pairs + 2``.

    Args:
        G: NetworkX graph
        K: List of (source, target) pairs; only the first ``max_pairs``
            entries are encoded
        T: Threshold
        node_embedding_dim: Dimension for node embeddings (currently unused;
            kept so existing callers keep working)
        max_pairs: Maximum number of pairs to encode

    Returns:
        condition_vector: 1-D float tensor on the module-level ``device``
            encoding the graph and problem instance
    """
    directed = nx.is_directed(G)

    # Graph structure features
    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()
    # With fewer than two nodes there are no possible edges; report a density
    # of 0 instead of dividing by zero.
    possible_edges = num_nodes * (num_nodes - 1)
    density = num_edges / possible_edges if possible_edges > 0 else 0.0

    # Node centrality features (in/out centrality only exists for digraphs)
    if directed:
        in_centrality = list(nx.in_degree_centrality(G).values())
        out_centrality = list(nx.out_degree_centrality(G).values())
        avg_in_centrality = sum(in_centrality) / len(in_centrality) if in_centrality else 0
        avg_out_centrality = sum(out_centrality) / len(out_centrality) if out_centrality else 0
    else:
        avg_in_centrality = 0.0
        avg_out_centrality = 0.0

    # Graph-level features
    graph_features = [
        num_nodes, num_edges, density,
        avg_in_centrality, avg_out_centrality,
        float(directed)
    ]

    # Encode critical pairs (limited to max_pairs)
    pair_feature_len = 6  # Number of features per pair
    pair_features = []
    for s, t in K[:max_pairs]:
        try:
            path_length = len(nx.shortest_path(G, s, t)) - 1
            path_exists = 1.0
        except nx.NetworkXNoPath:
            path_length = 0
            path_exists = 0.0

        if directed:
            s_in_degree, s_out_degree = G.in_degree(s), G.out_degree(s)
            t_in_degree, t_out_degree = G.in_degree(t), G.out_degree(t)
        else:
            # Undirected graphs have a single degree; use it for both slots.
            s_in_degree = s_out_degree = G.degree(s)
            t_in_degree = t_out_degree = G.degree(t)

        pair_features.extend([
            s_in_degree, s_out_degree,
            t_in_degree, t_out_degree,
            path_length, path_exists
        ])

    # Pad so the pair section always spans max_pairs slots
    padding = [0.0] * (max_pairs * pair_feature_len - len(pair_features))
    pair_features.extend(padding)

    # Threshold and its relation to average path length
    try:
        avg_path_length = nx.average_shortest_path_length(G)
    except nx.NetworkXException:
        # Disconnected or empty graphs have no defined average path length.
        avg_path_length = 5.0  # Default value

    threshold_features = [
        T,
        T / avg_path_length if avg_path_length > 0 else 1.0
    ]

    # Combine all features
    condition = graph_features + pair_features + threshold_features

    return torch.tensor(condition, dtype=torch.float, device=device)

    



def generate_training_data(G, num_samples=1000, batch_size=32):
    """
    Generate training and validation data for SPAGAN from a graph

    Distinct (source, target) node pairs are sampled with ``np.random`` and
    labelled with their weighted shortest-path cost. The pairs are grouped
    into chunks of ``batch_size``; each chunk becomes one ``Data`` object
    that shares the same graph tensors and carries its own
    ``source_indices``, ``target_indices`` and ``true_distances``. The first
    80% of the chunks form the training set, the rest the validation set.

    Args:
        G: NetworkX graph. The code calls ``G.in_degree()`` /
            ``G.out_degree()`` and reads ``G[u][v]['weight']``, so it needs a
            directed graph whose edges all carry a ``'weight'`` attribute.
        num_samples: Number of distinct source-target pairs requested.
        batch_size: Number of pairs stored in each ``Data`` object.

    Returns:
        Tuple ``(train_loader, val_loader, node_mapping)``. Both loaders use
        ``batch_size=1`` because each ``Data`` object already holds a batch
        of pairs. ``node_mapping`` maps each node to its row index.

    Raises:
        ValueError: If ``num_samples`` is positive but no connected pair
            could be sampled.
    """
    nodes = list(G.nodes())
    node_mapping = {node: i for i, node in enumerate(nodes)}

    # Generate source-target pairs with their shortest path distances
    print(f"Generating {num_samples} source-target pairs for training...")

    pairs = []
    tried = set()  # every (source, target) already examined, reachable or not
    total_ordered_pairs = len(nodes) * (len(nodes) - 1)
    # Bound the sampling so a graph with too few reachable pairs cannot hang.
    max_attempts = num_samples * 100
    attempts = 0

    with tqdm(total=num_samples) as pbar:
        while (
            len(pairs) < num_samples
            and attempts < max_attempts
            and len(tried) < total_ordered_pairs
        ):
            attempts += 1

            # Sample positions rather than node objects so NumPy never
            # coerces the node labels (tuple labels, mixed types, ...).
            src_pos, dst_pos = np.random.choice(len(nodes), 2, replace=False)
            source, target = nodes[src_pos], nodes[dst_pos]

            # Skip if already sampled
            if (source, target) in tried:
                continue
            tried.add((source, target))

            # Calculate shortest path
            try:
                path = nx.shortest_path(G, source=source, target=target, weight='weight')
            except nx.NetworkXNoPath:
                continue

            path_cost = sum(G[u][v]['weight'] for u, v in zip(path, path[1:]))

            # Only add pairs with valid paths
            pairs.append((source, target, path_cost))
            pbar.update(1)

    if num_samples > 0 and not pairs:
        raise ValueError("No connected source-target pairs could be sampled from the graph")
    if len(pairs) < num_samples:
        print(f"Warning: Could only find {len(pairs)} valid pairs out of {num_samples} requested")

    # Prepare node features: [in_degree, out_degree, degree_centrality]
    in_degree = dict(G.in_degree())
    out_degree = dict(G.out_degree())
    degree_centrality = dict(nx.degree_centrality(G))

    node_features = torch.zeros((len(nodes), 3), dtype=torch.float)
    for node, idx in node_mapping.items():
        node_features[idx] = torch.tensor([
            in_degree.get(node, 0),
            out_degree.get(node, 0),
            degree_centrality.get(node, 0)
        ], dtype=torch.float)

    # Prepare edge data
    edge_index = []
    edge_attr = []

    for u, v in G.edges():
        edge_index.append([node_mapping[u], node_mapping[v]])
        edge_attr.append([G[u][v]['weight']])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    # Per-pair tensors, sliced into batches below
    source_indices = torch.tensor([node_mapping[s] for s, _, _ in pairs], dtype=torch.long)
    target_indices = torch.tensor([node_mapping[t] for _, t, _ in pairs], dtype=torch.long)
    true_distances = torch.tensor([d for _, _, d in pairs], dtype=torch.float)

    # Create batches over the pairs actually collected (may be fewer than
    # num_samples); slicing clamps the final, possibly shorter, batch.
    data_list = []
    for start_idx in range(0, len(pairs), batch_size):
        end_idx = start_idx + batch_size

        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            source_indices=source_indices[start_idx:end_idx],
            target_indices=target_indices[start_idx:end_idx],
            true_distances=true_distances[start_idx:end_idx],
            perturbation=torch.zeros(edge_attr.size(0))
        )

        data_list.append(data)

    # Split into training and validation sets (80/20)
    split_idx = int(0.8 * len(data_list))
    train_data = data_list[:split_idx]
    val_data = data_list[split_idx:]

    print(f"Created {len(train_data)} training batches and {len(val_data)} validation batches")

    train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False)

    return train_loader, val_loader, node_mapping
