import networkx as nx
from typing import Dict, List, Set, Tuple, Optional
from ...core.representation import MolecularGraph
from ..actions import AssemblyAction
from ..state import AssemblyState


class TopologyConstraints:
    """Topological validity checks for step-wise motif assembly.

    Responsibilities:
      * validate a single connection action (self-loops, degree limit,
        duplicate edges, optional acyclicity);
      * validate the final graph when a STOP action is issued;
      * decide whether assembly may terminate (reconstruction / generation);
      * report connectivity progress and suggest component-merging links.

    ``state.current_graph`` / ``state.target_graph`` are ``MolecularGraph``
    objects whose ``.graph`` attribute is a networkx graph keyed by motif id.
    It must be undirected: ``nx.is_connected`` and ``nx.connected_components``
    (used below) reject directed graphs.
    """

    def __init__(self, max_components: int = 1, allow_cycles: bool = True,
                 max_degree: int = 4, min_connectivity: float = 1.0,
                 max_isolated_fraction: float = 0.1,
                 max_incomplete_fraction: float = 0.2):
        """
        Args:
            max_components: Largest number of connected components the
                connectivity check tolerates.
            allow_cycles: If False, connections that would close a cycle are
                rejected.
            max_degree: A motif that already has this many connections cannot
                accept another one.
            min_connectivity: Stored on the instance; none of the checks in
                this class read it.
            max_isolated_fraction: In generation mode, the final graph is
                rejected when more than this fraction of motifs have no
                connections (default 0.1, the previous hard-coded value).
            max_incomplete_fraction: In generation mode, termination is
                refused when more than this fraction of motifs still have
                more than 2 free sites (default 0.2, the previous hard-coded
                value).
        """
        self.max_components = max_components
        self.allow_cycles = allow_cycles
        self.max_degree = max_degree
        self.min_connectivity = min_connectivity
        self.max_isolated_fraction = max_isolated_fraction
        self.max_incomplete_fraction = max_incomplete_fraction

    # ------------------------------------------------------------------
    # Action validation
    # ------------------------------------------------------------------

    def validate_topology(self, state: AssemblyState, action: AssemblyAction) -> Tuple[bool, str]:
        """Return ``(is_valid, reason)`` for applying ``action`` to ``state``.

        A STOP action validates the whole final graph; any other action is
        validated as a single new connection. ``reason`` is ``""`` when valid.
        """
        if action.is_stop_action():
            return self._validate_final_topology(state)

        return self._validate_connection_topology(state, action)

    def _validate_final_topology(self, state: AssemblyState) -> Tuple[bool, str]:
        """Validate the graph at STOP time, according to ``state.mode``."""
        graph = state.current_graph

        if state.mode == "reconstruction":
            if state.target_graph:
                return self._validate_reconstruction_topology(graph, state.target_graph)
            return self._validate_connectivity(graph)
        # generation mode
        return self._validate_generation_topology(graph)

    def _validate_reconstruction_topology(self, current: MolecularGraph,
                                          target: MolecularGraph) -> Tuple[bool, str]:
        """Require the exact target edge set and a single connected component.

        NOTE(review): this STOP check rejects extra edges, while
        ``_check_reconstruction_termination`` accepts a superset of the target
        edges. Confirm that this asymmetry is intended.
        """
        current_edges = self._edge_set(current.graph)
        target_edges = self._edge_set(target.graph)

        missing_edges = target_edges - current_edges
        if missing_edges:
            return False, f"Missing edges: {self._preview_edges(missing_edges)}..."

        extra_edges = current_edges - target_edges
        if extra_edges:
            return False, f"Extra edges: {self._preview_edges(extra_edges)}..."

        if not self._is_connected(current.graph):
            return False, "Graph is not fully connected"

        return True, ""

    def _validate_connectivity(self, graph: MolecularGraph) -> Tuple[bool, str]:
        """Fail only if the graph has more than ``max_components`` components."""
        nx_graph = graph.graph
        if self._is_connected(nx_graph):
            return True, ""

        # Safe on the empty graph too (0 components), unlike nx.is_connected.
        num_components = nx.number_connected_components(nx_graph)
        if num_components > self.max_components:
            return False, f"Too many disconnected components: {num_components}"

        return True, ""

    def _validate_generation_topology(self, graph: MolecularGraph) -> Tuple[bool, str]:
        """Check connectivity, then reject graphs with too many unconnected motifs."""
        valid, msg = self._validate_connectivity(graph)
        if not valid:
            return False, msg

        num_motifs = len(graph.motifs)
        # A lone motif has nothing to connect to, so it is not "isolated".
        # Without this guard a single-motif graph is accepted by
        # _check_generation_termination but always rejected here.
        if num_motifs > 1:
            num_isolated = sum(
                1 for motif_id in graph.motifs
                if self._degree(graph.graph, motif_id) == 0
            )
            if num_isolated > num_motifs * self.max_isolated_fraction:
                return False, f"Too many isolated motifs: {num_isolated}"

        return True, ""

    def _validate_connection_topology(self, state: AssemblyState,
                                      action: AssemblyAction) -> Tuple[bool, str]:
        """Validate adding the edge ``action.source_motif -- action.target_motif``."""
        graph = state.current_graph
        nx_graph = graph.graph
        source, target = action.source_motif, action.target_motif

        # Checked first: otherwise a self-loop would be reported as a degree
        # or cycle violation instead of its real cause.
        if source == target:
            return False, "Self-loops not allowed"

        if (self._degree(nx_graph, source) >= self.max_degree
                or self._degree(nx_graph, target) >= self.max_degree):
            return False, f"Connection would exceed max degree {self.max_degree}"

        if nx_graph.has_edge(source, target):
            return False, "Connection already exists"

        if not self.allow_cycles and self._would_create_cycle(graph, action):
            return False, "Connection would create a cycle"

        return True, ""

    def _would_create_cycle(self, graph: MolecularGraph, action: AssemblyAction) -> bool:
        """Return True if adding the action's edge would close a cycle.

        Adding an edge ``source -> target`` closes a cycle exactly when a path
        ``target -> source`` already exists. For an undirected graph, that means
        the two motifs are already in the same component. A self-loop is itself
        a cycle.
        """
        source, target = action.source_motif, action.target_motif
        if source == target:
            return True

        nx_graph = graph.graph
        if source not in nx_graph or target not in nx_graph:
            # A motif that is not yet a node has no path to anything.
            return False

        # Do not use nx.simple_cycles(G.to_directed()): to_directed() turns
        # each undirected edge into two opposite arcs (a 2-cycle), so it
        # reports a cycle for every graph that has an edge. It is also
        # exponential in the number of cycles.
        return nx.has_path(nx_graph, target, source)

    # ------------------------------------------------------------------
    # Termination
    # ------------------------------------------------------------------

    def check_termination_conditions(self, state: AssemblyState) -> Tuple[bool, str]:
        """Return ``(should_terminate, reason)`` for the state's mode."""
        if state.mode == "reconstruction":
            return self._check_reconstruction_termination(state)
        return self._check_generation_termination(state)

    def _check_reconstruction_termination(self, state: AssemblyState) -> Tuple[bool, str]:
        """Terminate once the graph is one component containing every target edge."""
        if not state.target_graph:
            return False, "No target graph specified for reconstruction"

        current_graph = state.current_graph

        # Necessary condition: a single connected component.
        num_components = current_graph.num_connected_components()
        if num_components != 1:
            return False, f"Graph has {num_components} components, need 1"

        current_edges = self._edge_set(current_graph.graph)
        target_edges = self._edge_set(state.target_graph.graph)

        missing = target_edges - current_edges
        if missing:
            return False, f"Missing {len(missing)} target edges"

        if current_edges == target_edges:
            return True, "Perfect reconstruction achieved"

        # Every target edge is present here, so whatever remains is extra.
        extra = len(current_edges - target_edges)
        return True, f"Reconstruction complete with {extra} extra edges"

    def _check_generation_termination(self, state: AssemblyState) -> Tuple[bool, str]:
        """Terminate when motifs are mostly saturated, connected, and linked."""
        graph = state.current_graph
        num_motifs = len(graph.motifs)

        # Chemical completeness: more than 2 free sites on a motif is taken
        # as a sign that it is still incomplete.
        incomplete_motifs = sum(
            1 for motif in graph.motifs.values()
            if len(motif.get_available_sites()) > 2
        )
        if incomplete_motifs > num_motifs * self.max_incomplete_fraction:
            return False, f"Too many incomplete motifs: {incomplete_motifs}"

        # Topological completeness (the empty graph counts as not connected).
        if not self._is_connected(graph.graph):
            return False, "Graph is not connected"

        # A connected graph on n motifs needs at least n - 1 connections.
        expected_connections = num_motifs - 1
        actual_connections = len(graph.connections)
        if actual_connections < expected_connections:
            return False, f"Too few connections: {actual_connections} < {expected_connections}"

        return True, "Generation termination criteria met"

    # ------------------------------------------------------------------
    # Progress reporting and suggestions
    # ------------------------------------------------------------------

    def get_connectivity_progress(self, state: AssemblyState) -> Dict[str, float]:
        """Summarize how close the current graph is to a single component.

        Returns:
            ``connectivity_ratio``: 0.0 when every motif is its own component,
            1.0 when there is a single component.
            ``component_ratio``: ``1 / num_components``.
            With two or more motifs, also ``num_components`` (int) and
            ``is_connected`` (1.0 / 0.0). If the state has a target graph,
            also ``edge_completion``: the fraction of target edges present.
        """
        graph = state.current_graph
        num_motifs = len(graph.motifs)

        if num_motifs <= 1:
            return {'connectivity_ratio': 1.0, 'component_ratio': 1.0}

        num_components = graph.num_connected_components()
        # Assumes every motif starts out as its own component.
        max_components = num_motifs
        connectivity_ratio = 1.0 - (num_components - 1) / max(max_components - 1, 1)

        # 1.0 when fully connected, shrinking as the graph fragments.
        component_ratio = 1.0 / num_components if num_components > 0 else 0.0

        progress = {
            'connectivity_ratio': connectivity_ratio,
            'component_ratio': component_ratio,
            'num_components': num_components,
            'is_connected': float(self._is_connected(graph.graph)),
        }

        if state.target_graph:
            target_edges = self._edge_set(state.target_graph.graph)
            current_edges = self._edge_set(graph.graph)
            progress['edge_completion'] = len(current_edges & target_edges) / max(len(target_edges), 1)

        return progress

    def suggest_next_connections(self, state: AssemblyState, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Suggest up to ``top_k`` connections that would merge components.

        For each pair of components, keep the best-scoring motif pair (score
        strictly above 0). Results are sorted by descending score. Returns
        ``[]`` if the graph is already connected or ``top_k`` is not positive.
        """
        if top_k <= 0:
            return []

        nx_graph = state.current_graph.graph
        if self._is_connected(nx_graph):
            return []

        components = list(nx.connected_components(nx_graph))
        suggestions: List[Tuple[str, str, float]] = []

        for i, comp1 in enumerate(components):
            for comp2 in components[i + 1:]:
                best = self._best_connection(state, comp1, comp2)
                if best is not None:
                    suggestions.append(best)

        suggestions.sort(key=lambda suggestion: suggestion[2], reverse=True)
        return suggestions[:top_k]

    def _best_connection(self, state: AssemblyState, comp1, comp2) -> Optional[Tuple[str, str, float]]:
        """Return the highest-scoring ``(motif1, motif2, score)`` between two
        components, or None if no pair scores above 0 (ties keep the first)."""
        best: Optional[Tuple[str, str, float]] = None
        best_score = 0.0
        for motif1 in comp1:
            for motif2 in comp2:
                score = self._calculate_connection_score(state, motif1, motif2)
                if score > best_score:
                    best_score = score
                    best = (motif1, motif2, score)
        return best

    def _calculate_connection_score(self, state: AssemblyState, motif1: str, motif2: str) -> float:
        """Heuristic compatibility score for connecting ``motif1`` and ``motif2``.

        The score is 0.1 per shared allowed bond type, summed over every pairing
        of the two motifs' available sites, plus 0.5 if both motifs are
        aromatic. It is 0.0 if either motif is unknown or has no available sites.
        """
        motifs = state.current_graph.motifs
        if motif1 not in motifs or motif2 not in motifs:
            return 0.0

        motif1_obj = motifs[motif1]
        motif2_obj = motifs[motif2]

        sites1 = motif1_obj.get_available_sites()
        sites2 = motif2_obj.get_available_sites()
        if not sites1 or not sites2:
            return 0.0

        compatibility_score = sum(
            len(common) * 0.1
            for site1 in sites1
            for site2 in sites2
            if (common := site1.allowed_bond_types & site2.allowed_bond_types)
        ) if False else self._site_compatibility(sites1, sites2)

        # Bonus for aromatic-aromatic connections.
        if motif1_obj.is_aromatic and motif2_obj.is_aromatic:
            compatibility_score += 0.5

        return float(compatibility_score)

    @staticmethod
    def _site_compatibility(sites1, sites2) -> float:
        """Sum 0.1 per shared allowed bond type over all site pairings.

        The sum is accumulated term by term, in the same order as the nested
        loops, so float results (and therefore ranking ties) stay stable.
        """
        score = 0.0
        for site1 in sites1:
            for site2 in sites2:
                common_bonds = site1.allowed_bond_types & site2.allowed_bond_types
                if common_bonds:
                    score += len(common_bonds) * 0.1
        return score

    # ------------------------------------------------------------------
    # Graph helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_connected(nx_graph: nx.Graph) -> bool:
        """Like ``nx.is_connected``, but return False for the empty graph.

        ``nx.is_connected`` raises NetworkXPointlessConcept on the empty graph.
        """
        return nx_graph.number_of_nodes() > 0 and nx.is_connected(nx_graph)

    @staticmethod
    def _degree(nx_graph: nx.Graph, node) -> int:
        """Return ``node``'s degree, or 0 if it is not (yet) a node of the graph.

        For a missing node, ``G.degree(node)`` returns a view rather than an
        int, which would break numeric comparisons.
        """
        return nx_graph.degree(node) if node in nx_graph else 0

    @staticmethod
    def _edge_set(nx_graph: nx.Graph) -> Set[frozenset]:
        """Return the edges as orientation-free frozensets.

        networkx yields each undirected edge as ``(u, v)`` in an order that
        depends on node insertion order. Comparing raw tuples between two
        graphs can therefore report identical edges as different.
        """
        return {frozenset(edge) for edge in nx_graph.edges()}

    @staticmethod
    def _preview_edges(edges: Set[frozenset], limit: int = 3) -> List[tuple]:
        """Return up to ``limit`` edges as tuples, for error messages."""
        return [tuple(edge) for edge in list(edges)[:limit]]