import networkx as nx
import numpy as np
from typing import Dict, List, Tuple, Set, Optional, Any
import torch
from ...core.representation import MolecularGraph


class GraphUtils:
    @staticmethod
    def molecular_graph_to_networkx(mol_graph: MolecularGraph) -> nx.Graph:
        """Return a copy of the graph stored on ``mol_graph``.

        Args:
            mol_graph: Object whose ``graph`` attribute holds the graph to copy.

        Returns:
            Whatever ``mol_graph.graph.copy()`` produces, so the caller gets
            a separate object from the one stored on ``mol_graph``.
        """
        stored_graph = mol_graph.graph
        return stored_graph.copy()

    @staticmethod
    def networkx_to_adjacency_matrix(graph: nx.Graph) -> Tuple[torch.Tensor, List[str]]:
        nodes = list(graph.nodes())
        num_nodes = len(nodes)

        adjacency = torch.zeros((num_nodes, num_nodes), dtype=torch.float32)

        node_to_idx = {node: idx for idx, node in enumerate(nodes)}

        for edge in graph.edges():
            i = node_to_idx[edge[0]]
            j = node_to_idx[edge[1]]
            adjacency[i, j] = 1.0
            adjacency[j, i] = 1.0

        return adjacency, nodes

    @staticmethod
    def calculate_graph_statistics(graph: nx.Graph) -> Dict[str, Any]:
        if len(graph) == 0:
            return {'empty_graph': True}

        stats = {
            'num_nodes': graph.number_of_nodes(),
            'num_edges': graph.number_of_edges(),
            'density': nx.density(graph),
            'is_connected': nx.is_connected(graph),
            'num_connected_components': nx.number_connected_components(graph),
        }

        if nx.is_connected(graph):
            stats.update({
                'diameter': nx.diameter(graph),
                'radius': nx.radius(graph),
                'average_shortest_path_length': nx.average_shortest_path_length(graph),
            })
        else:
            stats.update({
                'diameter': float('inf'),
                'radius': float('inf'),
                'average_shortest_path_length': float('inf'),
            })

        # Degree statistics
        degrees = list(dict(graph.degree()).values())
        if degrees:
            stats.update({
                'average_degree': np.mean(degrees),
                'degree_std': np.std(degrees),
                'max_degree': max(degrees),
                'min_degree': min(degrees),
            })

        # Centrality measures
        try:
            degree_centrality = nx.degree_centrality(graph)
            betweenness_centrality = nx.betweenness_centrality(graph)
            closeness_centrality = nx.closeness_centrality(graph)

            stats.update({
                'avg_degree_centrality': np.mean(list(degree_centrality.values())),
                'avg_betweenness_centrality': np.mean(list(betweenness_centrality.values())),
                'avg_closeness_centrality': np.mean(list(closeness_centrality.values())),
            })
        except Exception:
            stats.update({
                'avg_degree_centrality': 0.0,
                'avg_betweenness_centrality': 0.0,
                'avg_closeness_centrality': 0.0,
            })

        # Clustering
        stats['average_clustering'] = nx.average_clustering(graph)
        stats['transitivity'] = nx.transitivity(graph)

        # Connectivity
        stats['edge_connectivity'] = nx.edge_connectivity(graph)
        stats['node_connectivity'] = nx.node_connectivity(graph)

        return stats

    @staticmethod
    def find_graph_motifs(graph: nx.Graph, motif_size: int = 3) -> List[Set]:
        motifs = []

        if motif_size == 3:
            # Find triangles
            triangles = [set(triangle) for triangle in nx.enumerate_all_cliques(graph) if len(triangle) == 3]
            motifs.extend(triangles)

        elif motif_size == 4:
            # Find 4-cliques and 4-cycles
            four_cliques = [set(clique) for clique in nx.enumerate_all_cliques(graph) if len(clique) == 4]
            motifs.extend(four_cliques)

        return motifs

    @staticmethod
    def calculate_graph_similarity(graph1: nx.Graph, graph2: nx.Graph, method: str = 'edit_distance') -> float:
        if method == 'edit_distance':
            try:
                distance = nx.graph_edit_distance(graph1, graph2, timeout=10)
                if distance == float('inf'):
                    return 0.0
                # Convert to similarity (0 to 1)
                max_nodes = max(len(graph1), len(graph2), 1)
                normalized_distance = distance / max_nodes
                return max(0, 1 - normalized_distance)
            except Exception:
                return 0.0

        elif method == 'jaccard':
            edges1 = set(graph1.edges())
            edges2 = set(graph2.edges())

            # Normalize edge tuples
            edges1 = {tuple(sorted(edge)) for edge in edges1}
            edges2 = {tuple(sorted(edge)) for edge in edges2}

            intersection = len(edges1 & edges2)
            union = len(edges1 | edges2)

            return intersection / union if union > 0 else 1.0

        else:
            raise ValueError(f"Unknown similarity method: {method}")

    @staticmethod
    def find_critical_nodes(graph: nx.Graph) -> List[Any]:
        """Return the nodes whose removal would split the graph further.

        A node is critical when deleting it increases the number of connected
        components, i.e. it is an articulation point. The articulation points
        are found in one pass by NetworkX instead of copying the graph and
        recounting components once per node.

        Args:
            graph: Undirected graph to inspect.

        Returns:
            The critical nodes, listed in the graph's node iteration order.
        """
        articulation = set(nx.articulation_points(graph))
        # Filter the node view so the result keeps the graph's node order.
        return [node for node in graph.nodes() if node in articulation]

    @staticmethod
    def find_critical_edges(graph: nx.Graph) -> List[Tuple]:
        """Return the edges whose removal would split the graph further.

        An edge is critical when deleting it increases the number of connected
        components, i.e. it is a bridge. The bridges are found in one pass by
        NetworkX instead of copying the graph and recounting components once
        per edge.

        Args:
            graph: Undirected graph to inspect.

        Returns:
            The critical edges, in the order and orientation produced by
            ``graph.edges()``.
        """
        # frozenset lets an edge match regardless of endpoint order.
        bridges = {frozenset(edge) for edge in nx.bridges(graph)}
        # Filter the edge view so order and orientation match graph.edges().
        return [edge for edge in graph.edges() if frozenset(edge) in bridges]

    @staticmethod
    def generate_graph_embedding(graph: nx.Graph, method: str = 'adjacency', dim: int = 64) -> Optional[torch.Tensor]:
        if len(graph) == 0:
            return torch.zeros(dim)

        try:
            if method == 'adjacency':
                # Simple adjacency-based embedding
                adj_matrix, nodes = GraphUtils.networkx_to_adjacency_matrix(graph)
                # Use eigendecomposition for embedding
                eigenvalues, eigenvectors = torch.linalg.eig(adj_matrix.float())
                # Take real parts and first 'dim' components
                embedding = eigenvectors.real[:, :min(dim, eigenvectors.size(1))].flatten()

                # Pad or truncate to desired dimension
                if len(embedding) < dim:
                    embedding = torch.cat([embedding, torch.zeros(dim - len(embedding))])
                else:
                    embedding = embedding[:dim]

                return embedding

            elif method == 'laplacian':
                # Laplacian-based embedding
                laplacian = nx.laplacian_matrix(graph).toarray()
                laplacian_tensor = torch.tensor(laplacian, dtype=torch.float32)

                eigenvalues, eigenvectors = torch.linalg.eig(laplacian_tensor)
                embedding = eigenvectors.real[:, :min(dim, eigenvectors.size(1))].flatten()

                if len(embedding) < dim:
                    embedding = torch.cat([embedding, torch.zeros(dim - len(embedding))])
                else:
                    embedding = embedding[:dim]

                return embedding

            else:
                raise ValueError(f"Unknown embedding method: {method}")

        except Exception:
            # Return random embedding as fallback
            return torch.randn(dim)

    @staticmethod
    def detect_communities(graph: nx.Graph, method: str = 'louvain') -> Dict[Any, int]:
        if method == 'louvain':
            try:
                import community as community_louvain
                partition = community_louvain.best_partition(graph)
                return partition
            except ImportError:
                # Fallback to simple connected components
                communities = {}
                for i, component in enumerate(nx.connected_components(graph)):
                    for node in component:
                        communities[node] = i
                return communities

        elif method == 'components':
            communities = {}
            for i, component in enumerate(nx.connected_components(graph)):
                for node in component:
                    communities[node] = i
            return communities

        else:
            raise ValueError(f"Unknown community detection method: {method}")

    @staticmethod
    def calculate_node_importance(graph: nx.Graph, method: str = 'degree') -> Dict[Any, float]:
        if method == 'degree':
            return nx.degree_centrality(graph)

        elif method == 'betweenness':
            return nx.betweenness_centrality(graph)

        elif method == 'closeness':
            return nx.closeness_centrality(graph)

        elif method == 'eigenvector':
            try:
                return nx.eigenvector_centrality(graph, max_iter=1000)
            except Exception:
                # Fallback to degree centrality
                return nx.degree_centrality(graph)

        elif method == 'pagerank':
            return nx.pagerank(graph)

        else:
            raise ValueError(f"Unknown importance method: {method}")

    @staticmethod
    def find_shortest_paths(graph: nx.Graph, source: Any, target: Optional[Any] = None) -> Dict:
        """Look up unweighted shortest paths starting at ``source``.

        Args:
            graph: Graph to search.
            source: Start node.
            target: End node, or ``None`` to query every reachable node.

        Returns:
            With ``target=None``: a dict mapping each node reachable from
            ``source`` to its hop distance. Otherwise
            ``{'length': hops, 'path': [nodes...]}``, or
            ``{'length': float('inf'), 'path': None}`` when no path exists.

        Raises:
            networkx.NodeNotFound: Propagated from NetworkX when ``source``
                (or ``target``) is not in the graph.
        """
        if target is None:
            # The single-source search only reports reachable nodes, so it
            # has no "no path" failure to guard against here.
            return dict(nx.single_source_shortest_path_length(graph, source))

        try:
            path = nx.shortest_path(graph, source, target)
        except nx.NetworkXNoPath:
            return {'length': float('inf'), 'path': None}

        # The search is unweighted, so the length is the number of edges on
        # the path; deriving it here avoids a second search.
        return {'length': len(path) - 1, 'path': path}

    @staticmethod
    def is_tree(graph: nx.Graph) -> bool:
        """Report whether ``graph`` is a tree, as decided by ``networkx.is_tree``."""
        verdict = nx.is_tree(graph)
        return verdict

    @staticmethod
    def is_forest(graph: nx.Graph) -> bool:
        """Report whether ``graph`` is a forest, as decided by ``networkx.is_forest``."""
        verdict = nx.is_forest(graph)
        return verdict

    @staticmethod
    def get_spanning_tree(graph: nx.Graph, method: str = 'minimum') -> nx.Graph:
        """Return a spanning tree of ``graph``, or a spanning forest if disconnected.

        The NetworkX spanning-tree builders already produce a spanning forest
        for disconnected input, so the graph is handed to them directly. This
        also makes an empty graph return an empty result instead of failing
        in the connectivity check. When several edges tie on weight, which of
        the equally good trees is returned follows the graph's own edge order.

        Args:
            graph: Undirected graph to span.
            method: ``'minimum'`` for a minimum spanning tree; any other
                value selects the maximum spanning tree.

        Returns:
            A graph containing every node of ``graph`` and the selected
            tree/forest edges.
        """
        if method == 'minimum':
            return nx.minimum_spanning_tree(graph)
        return nx.maximum_spanning_tree(graph)

    @staticmethod
    def visualize_graph_stats(graph: nx.Graph) -> str:
        stats = GraphUtils.calculate_graph_statistics(graph)

        summary = [
            f"Graph Statistics:",
            f"  Nodes: {stats.get('num_nodes', 0)}",
            f"  Edges: {stats.get('num_edges', 0)}",
            f"  Density: {stats.get('density', 0):.3f}",
            f"  Connected: {stats.get('is_connected', False)}",
            f"  Components: {stats.get('num_connected_components', 0)}",
            f"  Average Degree: {stats.get('average_degree', 0):.2f}",
            f"  Clustering: {stats.get('average_clustering', 0):.3f}",
        ]

        if stats.get('is_connected', False):
            summary.extend([
                f"  Diameter: {stats.get('diameter', 'inf')}",
                f"  Avg Path Length: {stats.get('average_shortest_path_length', 'inf'):.2f}",
            ])

        return "\n".join(summary)