import torch
import networkx as nx
from typing import Dict, List, Tuple, Set, Optional
from ..representation import MolecularGraph


class TopologyAnalyzer:
    """Compute structural (topology) descriptors for ``MolecularGraph`` objects.

    Every feature is derived from the node/edge structure of ``graph.graph``
    (a networkx graph); node and edge attributes are never read.
    """

    def __init__(self):
        # Feature cache keyed by ``id(graph)``.
        self.cached_features = {}
        # Structural signature recorded for each cache entry. ``id`` values are
        # recycled after garbage collection, and a graph object could be edited
        # in place between calls, so a cached entry is only trusted when the
        # graph's current structure still matches this signature.
        self._cache_signatures = {}

    def analyze_topology(self, graph: MolecularGraph) -> Dict[str, float]:
        """Return topology features for ``graph``, reusing the cache when valid.

        Returns:
            A fresh copy of the feature dict, so callers that mutate the result
            cannot corrupt the cached entry.
        """
        graph_id = id(graph)
        signature = self._structure_signature(graph.graph)

        cached = self.cached_features.get(graph_id)
        if cached is not None and self._cache_signatures.get(graph_id) == signature:
            return dict(cached)

        features = self._compute_topology_features(graph)
        self.cached_features[graph_id] = features
        self._cache_signatures[graph_id] = signature
        return dict(features)

    def _compute_topology_features(self, graph: MolecularGraph) -> Dict[str, float]:
        """Compute all topology features for ``graph`` (no caching).

        The null graph gets all-zero features. For a disconnected graph,
        ``diameter``, ``radius`` and ``avg_shortest_path`` are ``inf``.
        """
        nx_graph = graph.graph
        num_nodes = len(nx_graph)
        if num_nodes == 0:
            return self._empty_graph_features()

        connected = nx.is_connected(nx_graph)
        # Key order is kept stable in case callers vectorize ``features.values()``.
        features = {
            'num_nodes': num_nodes,
            'num_edges': len(nx_graph.edges),
            'density': nx.density(nx_graph),
            'num_connected_components': nx.number_connected_components(nx_graph),
            'is_connected': float(connected),
            'diameter': self._safe_diameter(nx_graph),
            'radius': self._safe_radius(nx_graph),
            'average_clustering': nx.average_clustering(nx_graph),
            'transitivity': nx.transitivity(nx_graph),
            'average_degree': sum(degree for _, degree in nx_graph.degree()) / num_nodes,
            'degree_centrality_mean': self._mean_centrality(nx.degree_centrality(nx_graph)),
            'betweenness_centrality_mean': self._mean_centrality(nx.betweenness_centrality(nx_graph)),
            'closeness_centrality_mean': self._mean_centrality(nx.closeness_centrality(nx_graph)),
            'num_cycles': len(nx.cycle_basis(nx_graph)),
            'edge_connectivity': nx.edge_connectivity(nx_graph),
            'node_connectivity': nx.node_connectivity(nx_graph),
        }

        features.update(self._compute_degree_distribution(nx_graph))
        features.update(self._compute_path_statistics(nx_graph))
        return features

    def _empty_graph_features(self) -> Dict[str, float]:
        """Return the full feature dict with every value set to ``0.0``."""
        return {key: 0.0 for key in [
            'num_nodes', 'num_edges', 'density', 'num_connected_components',
            'is_connected', 'diameter', 'radius', 'average_clustering',
            'transitivity', 'average_degree', 'degree_centrality_mean',
            'betweenness_centrality_mean', 'closeness_centrality_mean',
            'num_cycles', 'edge_connectivity', 'node_connectivity',
            'degree_std', 'max_degree', 'min_degree',
            'avg_shortest_path', 'path_length_std'
        ]}

    def _safe_diameter(self, graph: nx.Graph) -> float:
        """Return the diameter; ``inf`` if disconnected, ``0.0`` if networkx fails."""
        return self._safe_distance_extreme(graph, nx.diameter)

    def _safe_radius(self, graph: nx.Graph) -> float:
        """Return the radius; ``inf`` if disconnected, ``0.0`` if networkx fails."""
        return self._safe_distance_extreme(graph, nx.radius)

    @staticmethod
    def _safe_distance_extreme(graph: nx.Graph, metric) -> float:
        """Apply an eccentricity-based ``metric`` (diameter/radius) defensively.

        Raises whatever ``nx.is_connected`` raises, e.g. for the null graph.
        """
        if not nx.is_connected(graph):
            return float('inf')
        try:
            return float(metric(graph))
        except nx.NetworkXException:
            # Best-effort fallback; only networkx failures are swallowed so that
            # KeyboardInterrupt / programming errors still propagate.
            return 0.0

    def _mean_centrality(self, centrality_dict: Dict) -> float:
        """Return the mean of a ``{node: score}`` dict, or ``0.0`` when it is empty."""
        if not centrality_dict:
            return 0.0
        return sum(centrality_dict.values()) / len(centrality_dict)

    def _compute_degree_distribution(self, graph: nx.Graph) -> Dict[str, float]:
        """Return standard deviation, maximum and minimum of the node degrees."""
        degrees = [degree for _, degree in graph.degree()]
        if not degrees:
            return {'degree_std': 0.0, 'max_degree': 0.0, 'min_degree': 0.0}

        # torch's default ``std`` uses the unbiased (n - 1) estimator, which is
        # NaN for a single sample; a single node has zero spread.
        if len(degrees) > 1:
            degree_std = torch.tensor(degrees, dtype=torch.float32).std().item()
        else:
            degree_std = 0.0
        return {
            'degree_std': degree_std,
            'max_degree': float(max(degrees)),
            'min_degree': float(min(degrees)),
        }

    def _compute_path_statistics(self, graph: nx.Graph) -> Dict[str, float]:
        """Return the mean and std of shortest-path lengths over ordered node pairs.

        Only pairs of *distinct* nodes are counted, matching the definition used
        by ``nx.average_shortest_path_length``. Graphs with fewer than two nodes,
        or disconnected graphs, yield ``inf`` / ``0.0``.
        """
        # Check the size first: ``nx.is_connected`` raises on the null graph.
        if len(graph) < 2 or not nx.is_connected(graph):
            return {'avg_shortest_path': float('inf'), 'path_length_std': 0.0}

        path_lengths = []
        for source in graph.nodes:
            lengths = nx.single_source_shortest_path_length(graph, source)
            # Drop the trivial distance-0 path from ``source`` to itself, which
            # would otherwise bias the mean towards zero.
            path_lengths.extend(
                length for target, length in lengths.items() if target != source
            )

        # Connected with >= 2 nodes => at least n * (n - 1) >= 2 samples, so the
        # unbiased std is well defined.
        path_tensor = torch.tensor(path_lengths, dtype=torch.float32)
        return {
            'avg_shortest_path': path_tensor.mean().item(),
            'path_length_std': path_tensor.std().item(),
        }

    def compute_graph_edit_distance(self, graph1: MolecularGraph, graph2: MolecularGraph,
                                    *, timeout: float = 10) -> float:
        """Return the graph edit distance between two graphs, or ``inf`` on failure.

        Args:
            timeout: Seconds passed to ``nx.graph_edit_distance``; the best
                distance found so far is returned when it expires.

        Returns:
            The edit distance as a float. ``inf`` if networkx found no edit path
            before the timeout (it then returns ``None``) or raised an error.
        """
        try:
            distance = nx.graph_edit_distance(graph1.graph, graph2.graph, timeout=timeout)
        except Exception:
            # Deliberately best-effort: callers treat ``inf`` as "no similarity".
            return float('inf')
        if distance is None:
            return float('inf')
        return float(distance)

    def find_critical_edges(self, graph: MolecularGraph) -> List[Tuple[str, str]]:
        """Return edges whose removal increases the number of connected components.

        Brute force: each edge is removed from a copy of the graph in turn. For
        simple undirected graphs this should agree with ``find_bridges``.
        """
        nx_graph = graph.graph
        baseline_components = nx.number_connected_components(nx_graph)
        critical_edges = []

        for edge in nx_graph.edges():
            trial_graph = nx_graph.copy()
            trial_graph.remove_edge(*edge)
            if nx.number_connected_components(trial_graph) > baseline_components:
                critical_edges.append(edge)

        return critical_edges

    def find_bridges(self, graph: MolecularGraph) -> List[Tuple[str, str]]:
        """Return the bridges of ``graph.graph`` as found by ``nx.bridges``."""
        return list(nx.bridges(graph.graph))

    def compute_betweenness_edges(self, graph: MolecularGraph) -> Dict[Tuple[str, str], float]:
        """Return ``nx.edge_betweenness_centrality`` (networkx defaults) for ``graph``."""
        return nx.edge_betweenness_centrality(graph.graph)

    def get_motif_importance_scores(self, graph: MolecularGraph) -> Dict[str, float]:
        """Score each node as 0.3*degree + 0.4*betweenness + 0.3*closeness centrality."""
        nx_graph = graph.graph

        degree_centrality = nx.degree_centrality(nx_graph)
        betweenness_centrality = nx.betweenness_centrality(nx_graph)
        closeness_centrality = nx.closeness_centrality(nx_graph)

        return {
            node: (
                0.3 * degree_centrality.get(node, 0) +
                0.4 * betweenness_centrality.get(node, 0) +
                0.3 * closeness_centrality.get(node, 0)
            )
            for node in nx_graph.nodes
        }

    def compute_assembly_progress(self, current_graph: MolecularGraph,
                                  target_graph: MolecularGraph) -> Dict[str, float]:
        """Compare a partially assembled graph against its target.

        Edges are matched orientation-insensitively unless both graphs are
        directed, so an undirected edge ``(a, b)`` matches ``(b, a)``.

        Returns:
            ``edge_completion_ratio``: fraction of target edges present.
            ``connectivity_progress``: 1 minus normalized extra components.
            ``over_connection_penalty``: non-target edges per target edge.
            ``topology_similarity``: ``1 / (1 + GED)``, ``0.0`` when GED is ``inf``.
        """
        current_nx = current_graph.graph
        target_nx = target_graph.graph
        directed = current_nx.is_directed() and target_nx.is_directed()

        current_edges = set(self._normalized_edges(current_nx, directed))
        target_edges = set(self._normalized_edges(target_nx, directed))

        correctly_formed = len(current_edges & target_edges)
        total_target = len(target_edges)
        over_connected = len(current_edges - target_edges)

        return {
            'edge_completion_ratio': correctly_formed / max(total_target, 1),
            'connectivity_progress': 1.0 - (current_graph.num_connected_components() - 1) / max(len(current_graph.motifs) - 1, 1),
            'over_connection_penalty': over_connected / max(total_target, 1),
            'topology_similarity': 1.0 / (1.0 + self.compute_graph_edit_distance(current_graph, target_graph)),
        }

    @staticmethod
    def _normalized_edges(nx_graph: nx.Graph, directed: bool) -> List:
        """Return hashable edge keys; undirected edges become frozensets of endpoints."""
        if directed:
            return list(nx_graph.edges())
        return [frozenset(edge) for edge in nx_graph.edges()]

    @classmethod
    def _structure_signature(cls, nx_graph: nx.Graph) -> Tuple[bool, frozenset, frozenset]:
        """Return a hashable fingerprint of the graph's nodes and edge multiset.

        All features computed by this class are unweighted and structural, so two
        graphs with equal signatures produce identical features.
        """
        multiplicity: Dict = {}
        for key in cls._normalized_edges(nx_graph, nx_graph.is_directed()):
            multiplicity[key] = multiplicity.get(key, 0) + 1
        return (
            nx_graph.is_directed(),
            frozenset(nx_graph.nodes),
            frozenset(multiplicity.items()),
        )