import torch
import networkx as nx
from typing import List, Dict, Set, Tuple, Optional
from .motif import Motif, Connection


class MolecularGraph:
    """Undirected graph of motifs joined by typed connections.

    The same topology is mirrored in three structures that this class keeps
    in sync:

    * ``connections`` -- the ordered list of accepted ``Connection`` objects;
    * ``adjacency_matrix`` -- a symmetric 0/1 ``torch`` tensor whose rows and
      columns follow the insertion order of ``motifs``;
    * ``graph`` -- a ``networkx.Graph`` used for topology queries.

    ``edge_types`` maps ``(source_motif, target_motif)`` keys, in the
    orientation the connection was added, to that connection's ``bond_type``.
    """

    def __init__(self, motifs: 'List[Motif]'):
        """Create a graph with one isolated node per distinct motif.

        Args:
            motifs: Motifs to include, keyed by their ``motif_id``. If two
                motifs share an id, the later one replaces the earlier.
        """
        self.motifs = {motif.motif_id: motif for motif in motifs}
        # Row/column of each motif in ``adjacency_matrix``. The matrix size is
        # fixed here, so the mapping is computed once rather than rebuilding
        # ``list(self.motifs)`` on every lookup.
        self._motif_index: Dict[str, int] = {
            motif_id: idx for idx, motif_id in enumerate(self.motifs)
        }
        self.connections: List[Connection] = []
        # Sized by distinct ids (not the raw input length) so that every row
        # and column corresponds to a real motif even if ids were duplicated.
        num_motifs = len(self.motifs)
        self.adjacency_matrix = torch.zeros((num_motifs, num_motifs))
        self.edge_types = {}
        self.graph = nx.Graph()
        self._initialize_graph()

    def _initialize_graph(self):
        """Add every motif id to ``self.graph`` as an isolated node."""
        self.graph.add_nodes_from(self.motifs)

    def add_connection(self, connection: 'Connection'):
        """Record ``connection`` if both of its endpoints are known motifs.

        Connections that reference a motif id not in ``motifs`` are silently
        ignored. Adding a pair that is already connected appends another
        entry to ``connections`` and updates the edge's ``bond_type`` in
        ``graph``; the adjacency matrix keeps a single symmetric entry.

        Args:
            connection: Object exposing ``source_motif``, ``target_motif``
                and ``bond_type``.
        """
        src = connection.source_motif
        tgt = connection.target_motif
        if src not in self.motifs or tgt not in self.motifs:
            return

        self.connections.append(connection)

        src_idx = self._motif_index[src]
        tgt_idx = self._motif_index[tgt]
        # Undirected graph: mirror the entry so the matrix stays symmetric.
        self.adjacency_matrix[src_idx, tgt_idx] = 1
        self.adjacency_matrix[tgt_idx, src_idx] = 1

        self.edge_types[(src, tgt)] = connection.bond_type

        self.graph.add_edge(src, tgt, bond_type=connection.bond_type)

    def remove_connection(self, source_motif: str, target_motif: str):
        """Remove every connection between ``source_motif`` and ``target_motif``.

        The graph is undirected, so connections recorded in either
        orientation are dropped, along with their ``edge_types`` entries.
        This keeps ``connections`` consistent with ``graph`` and
        ``adjacency_matrix``, which have no notion of direction. Unknown
        motif ids are ignored, mirroring ``add_connection``.

        Args:
            source_motif: Id of one endpoint.
            target_motif: Id of the other endpoint.
        """
        if source_motif not in self.motifs or target_motif not in self.motifs:
            return

        orientations = (
            (source_motif, target_motif),
            (target_motif, source_motif),
        )
        self.connections = [
            c for c in self.connections
            if (c.source_motif, c.target_motif) not in orientations
        ]
        for key in orientations:
            self.edge_types.pop(key, None)

        if self.graph.has_edge(source_motif, target_motif):
            self.graph.remove_edge(source_motif, target_motif)

        src_idx = self._motif_index[source_motif]
        tgt_idx = self._motif_index[target_motif]
        self.adjacency_matrix[src_idx, tgt_idx] = 0
        self.adjacency_matrix[tgt_idx, src_idx] = 0

    def get_connected_components(self) -> List[Set[str]]:
        """Return the motif ids of each connected component as a set."""
        return [set(component) for component in nx.connected_components(self.graph)]

    def num_connected_components(self) -> int:
        """Return the number of connected components in ``graph``."""
        return nx.number_connected_components(self.graph)

    def is_connected(self) -> bool:
        """Return True if every motif is reachable from every other motif.

        A graph with no motifs is reported as not connected. networkx raises
        ``NetworkXPointlessConcept`` for the null graph, which would otherwise
        make ``get_topology_features`` fail on an empty molecule.
        """
        if len(self.graph) == 0:
            return False
        return nx.is_connected(self.graph)

    def get_shortest_path_length(self, source: str, target: str) -> Optional[int]:
        """Return the number of edges on a shortest path, or None if unreachable.

        Raises:
            networkx.NodeNotFound: If ``source`` or ``target`` is not a node.
        """
        try:
            return nx.shortest_path_length(self.graph, source, target)
        except nx.NetworkXNoPath:
            return None

    def get_graph_edit_distance(self, target_graph: 'MolecularGraph') -> float:
        """Return the graph edit distance between this graph and ``target_graph``.

        No node or edge match function is supplied, so node ids and the
        ``bond_type`` edge attribute are not compared, only the topology.
        Exact edit distance is expensive to compute, so this may be slow on
        larger graphs.
        """
        return nx.graph_edit_distance(self.graph, target_graph.graph)

    def to_text_representation(self) -> str:
        """Render a multi-line summary: counts, each motif, then each connection."""
        text_parts = [
            "MOLECULAR_GRAPH:",
            f"MOTIFS: {len(self.motifs)}",
            f"CONNECTIONS: {len(self.connections)}",
            f"CONNECTED_COMPONENTS: {self.num_connected_components()}",
            "",
            "MOTIF_DETAILS:",
        ]

        for motif in self.motifs.values():
            text_parts.append(motif.to_text_representation())
            text_parts.append("")

        text_parts.append("CONNECTIONS:")
        for connection in self.connections:
            text_parts.append(connection.to_text_representation())

        return "\n".join(text_parts)

    def get_topology_features(self) -> Dict[str, float]:
        """Return scalar topology descriptors of the graph.

        Keys: ``num_motifs``, ``num_connections`` (entries in
        ``connections``), ``num_components``, ``is_connected`` (as 0.0/1.0),
        ``avg_degree`` (0 for an empty graph) and ``clustering_coefficient``
        (0 for an empty graph).
        """
        return {
            'num_motifs': len(self.motifs),
            'num_connections': len(self.connections),
            'num_components': self.num_connected_components(),
            'is_connected': float(self.is_connected()),
            'avg_degree': sum(dict(self.graph.degree()).values()) / len(self.motifs) if self.motifs else 0,
            'clustering_coefficient': nx.average_clustering(self.graph) if len(self.graph) > 0 else 0,
        }

    def get_available_connection_sites(self) -> List[Tuple[str, int]]:
        """Return ``(motif_id, site_id)`` for every available site of every motif.

        Motifs are visited in insertion order, and within each motif the
        order of ``motif.get_available_sites()`` is preserved.
        """
        return [
            (motif_id, site.site_id)
            for motif_id, motif in self.motifs.items()
            for site in motif.get_available_sites()
        ]