import torch
import networkx as nx
import numpy as np
import os
from torch_geometric.utils import to_networkx, from_networkx
from sklearn.metrics import normalized_mutual_info_score as NMI
from tqdm import tqdm

def _budget_to_count(budget, available):
    """Return how many items ``budget`` admits, clamped to ``[0, available]``.

    Matches a ``while count < budget`` loop: a fractional budget rounds up and
    a non-positive budget admits nothing.
    """
    return int(min(available, max(0, np.ceil(budget))))


def _cosine_similarity_matrix(features):
    """Return the dense pairwise cosine-similarity matrix of ``features`` as a NumPy array.

    The computation runs on whatever device ``features`` lives on; the result is
    moved to the CPU. Zero-norm rows yield 0/0 = NaN, which is replaced by 0.
    """
    features = features.detach()
    if not torch.is_floating_point(features):
        # torch.norm rejects integer dtypes (e.g. binary bag-of-words features).
        features = features.float()
    similarities = torch.mm(features, features.t())
    norms = torch.norm(features, dim=1, keepdim=True)
    similarities = (similarities / (norms * norms.t())).cpu().numpy()
    if np.isnan(similarities).any():
        print("Warning: NaN detected in similarities matrix. Using fallback method.")
        similarities = np.nan_to_num(similarities, nan=0.0)
    return similarities


def _select_edges_to_add(similarities, src, dst, n_nodes, budget):
    """Return a ``(2, k)`` int64 array of the most similar ordered non-edge pairs.

    Candidates are all ordered pairs ``(i, j)`` with ``i != j`` that are not in
    the existing edge list ``(src, dst)``. They are ranked by similarity,
    highest first; ties keep row-major ``(i, j)`` order.
    """
    requested = _budget_to_count(budget, n_nodes * n_nodes)
    if requested == 0:
        # Skip building the dense n x n candidate mask when nothing is added.
        return np.empty((2, 0), dtype=np.int64)
    candidate = np.ones((n_nodes, n_nodes), dtype=bool)
    np.fill_diagonal(candidate, False)
    candidate[src, dst] = False
    # flatnonzero walks the mask in row-major order (the order of a nested
    # ``for i: for j:`` loop), so the stable sort below breaks ties that way.
    rows, cols = np.divmod(np.flatnonzero(candidate), n_nodes)
    order = np.argsort(-similarities[rows, cols], kind="stable")[:requested]
    return np.stack([rows[order], cols[order]]).astype(np.int64, copy=False)


def _select_edges_to_keep(edge_scores, budget):
    """Return ``(keep_mask, n_removed)`` that drops the lowest-scoring edge positions.

    Ties are broken by original position, so earlier edges are dropped first.
    """
    n_removed = _budget_to_count(budget, edge_scores.shape[0])
    keep_mask = np.ones(edge_scores.shape[0], dtype=bool)
    keep_mask[np.argsort(edge_scores, kind="stable")[:n_removed]] = False
    return keep_mask, n_removed


def _edge_homophily(edge_index, y):
    """Return the fraction of edges whose endpoints share a label, or NaN if there are none."""
    if edge_index.size(1) == 0:
        return float("nan")
    edges = edge_index.cpu().numpy()
    labels = y.cpu().numpy()
    return float((labels[edges[0]] == labels[edges[1]]).mean())


def _copy_data(data):
    """Return a copy of ``data`` whose ``edge_index`` can be replaced without touching ``data``.

    Objects with ``clone()`` (PyG ``Data``) are cloned; otherwise a bare object
    receives the listed attributes (values shared, not copied).
    """
    if hasattr(data, "clone"):
        return data.clone()
    clone = type("Data", (), {})()
    for attr in ["x", "y", "train_mask", "val_mask", "test_mask", "num_nodes", "homophily"]:
        if hasattr(data, attr):
            setattr(clone, attr, getattr(data, attr))
    return clone


def modify_graph(data, dataset_name, G, communities, budget_edges_add, budget_edges_delete):
    """
    Rewire the graph based on feature similarity to maximize feature similarity between connected nodes.

    The ``budget_edges_add`` most cosine-similar ordered node pairs that are not
    yet edges are added, and the ``budget_edges_delete`` least similar existing
    edges are removed. Edges are handled as ordered (directed) pairs: for a graph
    that stores both directions of each undirected edge, a budget may add or
    remove only one direction of a pair.

    NOTE: the similarity matrix (and, when adding, the candidate mask) is a
    dense n x n array.

    Args:
        data: PyG data object (or any object with ``x``, ``y``, ``edge_index``
            and ``num_nodes``). It is not modified.
        dataset_name: Name of the dataset. Unused; kept for interface compatibility.
        G: NetworkX graph. Unused; kept for interface compatibility.
        communities: List of communities in the graph. Unused; kept for
            interface compatibility.
        budget_edges_add: Number of edges to add (clamped to the available non-edges).
        budget_edges_delete: Number of edges to delete (clamped to the existing edges).

    Returns:
        Modified data object with the rewired ``edge_index`` (shape ``(2, E)``,
        on the same device as the input ``edge_index``) and its edge
        ``homophily`` (NaN when the rewired graph has no edges).
    """
    print("=============================================================")
    print("Rewiring based on feature similarity...")
    print(f"Budget for adding edges: {budget_edges_add}")
    print(f"Budget for deleting edges: {budget_edges_delete}")

    original_edge_index = data.edge_index
    # Keep the result on the input's device; moving only edge_index elsewhere
    # would leave it on a different device than x / y.
    edge_device = original_edge_index.device
    src, dst = original_edge_index.cpu().numpy()

    similarities = _cosine_similarity_matrix(data.x)

    new_edges = _select_edges_to_add(similarities, src, dst, data.num_nodes, budget_edges_add)
    edges_added = new_edges.shape[1]

    # Per-position mask: a duplicated edge column is removed as often as it is
    # selected, instead of every copy surviving when any copy is kept.
    keep_mask, edges_removed = _select_edges_to_keep(similarities[src, dst], budget_edges_delete)

    # Surviving original edges keep their order; new edges follow in rank order.
    # Concatenating (2, k) pieces keeps a (2, E) shape even when E == 0.
    edge_index = torch.cat(
        [
            original_edge_index[:, torch.from_numpy(keep_mask).to(edge_device)].long(),
            torch.from_numpy(new_edges).to(edge_device),
        ],
        dim=1,
    )

    data_rewired = _copy_data(data)
    # NOTE(review): only edge_index is replaced; per-edge attributes carried by
    # `data` (e.g. edge_attr), if any, will no longer align with it.
    data_rewired.edge_index = edge_index
    data_rewired.homophily = _edge_homophily(edge_index, data.y)

    original_homophily = getattr(data, "homophily", None)
    if original_homophily is None:
        original_homophily = _edge_homophily(original_edge_index, data.y)

    print(f"Added {edges_added} edges with high feature similarity")
    print(f"Removed {edges_removed} edges with low feature similarity")
    # Both counts are edge_index columns so they are directly comparable
    # (G.number_of_edges() may count an undirected edge only once).
    print(f"Original edges: {original_edge_index.size(1)}")
    print(f"New edges: {edge_index.size(1)}")
    print(f"Original homophily: {original_homophily:.4f}")
    print(f"New homophily: {data_rewired.homophily:.4f}")

    return data_rewired
