from typing import Dict, List, Set, Tuple, Optional
from rdkit import Chem
from ...core.representation import MolecularGraph, Motif, Connection
from ..actions import AssemblyAction


class ChemicalConstraints:
    """Rule-based chemistry checks for motif-assembly actions.

    Validates proposed motif-to-motif connections against per-site bond-type
    whitelists, a site-type bond compatibility table, and simple topology and
    aromaticity heuristics. It can also audit the connections already stored
    in a ``MolecularGraph`` and propose corrections for the problems found.
    """

    def __init__(self):
        # Per-element valence data keyed by atomic number.
        # NOTE(review): no check in this class consults it yet.
        self.valence_rules = self._initialize_valence_rules()
        # Allowed bond-type names per (site_type, site_type) pair; looked up
        # in both orders by _validate_bond_compatibility.
        self.bond_compatibility = self._initialize_bond_compatibility()
        # SMILES/SMARTS strings for disallowed substructures.
        # NOTE(review): not matched against anything yet (see
        # _check_forbidden_patterns).
        self.forbidden_patterns = self._initialize_forbidden_patterns()

    def _initialize_valence_rules(self) -> Dict[int, Dict[str, object]]:
        """Return valence data keyed by atomic number.

        Each entry maps ``'max_valence'`` to an int and ``'common_valences'``
        to a list of ints, hence the ``object`` value type.
        """
        return {
            6: {'max_valence': 4, 'common_valences': [4]},      # Carbon
            7: {'max_valence': 5, 'common_valences': [3, 5]},   # Nitrogen
            8: {'max_valence': 2, 'common_valences': [2]},      # Oxygen
            16: {'max_valence': 6, 'common_valences': [2, 4, 6]}, # Sulfur
            15: {'max_valence': 5, 'common_valences': [3, 5]},  # Phosphorus
            9: {'max_valence': 1, 'common_valences': [1]},      # Fluorine
            17: {'max_valence': 1, 'common_valences': [1]},     # Chlorine
            35: {'max_valence': 1, 'common_valences': [1]},     # Bromine
            53: {'max_valence': 1, 'common_valences': [1]}      # Iodine
        }

    def _initialize_bond_compatibility(self) -> Dict[Tuple[str, str], Set[str]]:
        """Return the allowed bond types for each (site_type, site_type) pair.

        Pairs are listed in one order only; callers also look up the reversed
        pair. Pairs absent from the table impose no restriction.
        """
        return {
            ('C_sp3', 'C_sp3'): {'SINGLE'},
            ('C_sp3', 'C_sp2'): {'SINGLE'},
            ('C_sp3', 'C_sp'): {'SINGLE'},
            ('C_sp2', 'C_sp2'): {'SINGLE', 'DOUBLE'},
            ('C_sp2', 'C_sp'): {'SINGLE'},
            ('C_sp', 'C_sp'): {'SINGLE', 'TRIPLE'},
            ('C_sp3', 'N_sp3'): {'SINGLE'},
            ('C_sp3', 'N_sp2'): {'SINGLE'},
            ('C_sp2', 'N_sp3'): {'SINGLE'},
            ('C_sp2', 'N_sp2'): {'SINGLE', 'DOUBLE'},
            ('C_sp3', 'O_sp3'): {'SINGLE'},
            ('C_sp2', 'O_sp3'): {'SINGLE'},
            ('C_sp2', 'O_sp2'): {'SINGLE', 'DOUBLE'},
            ('C_sp3', 'S_sp3'): {'SINGLE'},
            ('C_sp2', 'S_sp3'): {'SINGLE'},
            ('C_sp2', 'S_sp2'): {'SINGLE', 'DOUBLE'},
            ('aromatic', 'aromatic'): {'AROMATIC', 'SINGLE'},
        }

    def _initialize_forbidden_patterns(self) -> List[str]:
        """Return SMILES/SMARTS strings for substructures to reject.

        NOTE(review): these are stored but never matched by this class.
        """
        return [
            'C#C#C',        # Adjacent triple bonds
            'C=C=C=C',      # Long cumulative double bonds
            '[N+]#[N+]',    # Adjacent positive nitrogens with triple bond
            'O=O',          # Direct O=O bond (unstable)
            '[OH][OH]',     # Adjacent OH groups (prefer H2O2 with single bond)
            'C1C1',         # Two-membered carbon ring
            'C1N1',         # Two-membered heterocycle
        ]

    def validate_action(self, action: "AssemblyAction", graph: "MolecularGraph") -> Tuple[bool, str]:
        """Check whether ``action`` may be applied to ``graph``.

        Stop actions are always accepted. A connect action must meet all of
        the following:

        - it is well formed;
        - it references motifs present in ``graph.motifs``;
        - it does not duplicate an existing edge of ``graph.graph``;
        - it passes the site, bond-compatibility and topology/aromaticity
          checks.

        Returns:
            ``(True, "")`` if the action is acceptable, otherwise
            ``(False, message)`` describing the first failed check.
        """
        return self._validate(action, graph, check_duplicate=True)

    def _validate(
        self,
        action: "AssemblyAction",
        graph: "MolecularGraph",
        check_duplicate: bool = True,
        exclude: "Optional[Connection]" = None,
    ) -> Tuple[bool, str]:
        """Shared core of validate_action and get_constraint_violations.

        Args:
            action: the action to check.
            graph: the graph the action would be applied to.
            check_duplicate: reject the action if ``graph.graph`` already has
                an edge between the two motifs. The audit turns this off
                because the connection being audited is itself in the graph.
            exclude: an existing connection to ignore in the ring check
                (the one being audited), or None.

        Returns:
            ``(True, "")`` or ``(False, message)`` for the first failed check.
        """
        if action.is_stop_action():
            return True, ""

        if not action.is_valid_connect_action():
            return False, "Invalid connection action format"

        if (action.source_motif not in graph.motifs or
                action.target_motif not in graph.motifs):
            return False, "Source or target motif not found"

        if check_duplicate and graph.graph.has_edge(action.source_motif, action.target_motif):
            return False, "Connection already exists between these motifs"

        # Per-site checks, in order; the first failure is reported.
        for check in (self._validate_valence_constraints, self._validate_bond_compatibility):
            valid, msg = check(action, graph)
            if not valid:
                return False, msg

        return self._check_forbidden_patterns(action, graph, exclude=exclude)

    @staticmethod
    def _lookup_sites(action: "AssemblyAction", graph: "MolecularGraph"):
        """Return ``(source_site, target_site)`` for ``action``, or None.

        The motifs must already exist in ``graph.motifs``. Sites are matched
        by ``site_id``. If a motif lists the same id twice, the later site
        wins. Returns None when either site id is not found.
        """
        source_motif = graph.motifs[action.source_motif]
        target_motif = graph.motifs[action.target_motif]

        source_sites = {site.site_id: site for site in source_motif.connection_sites}
        target_sites = {site.site_id: site for site in target_motif.connection_sites}

        if (action.source_site not in source_sites or
                action.target_site not in target_sites):
            return None
        return source_sites[action.source_site], target_sites[action.target_site]

    def _validate_valence_constraints(self, action: "AssemblyAction", graph: "MolecularGraph") -> Tuple[bool, str]:
        """Check that both sites exist and each whitelists the bond type.

        NOTE(review): despite the name, ``self.valence_rules`` is not
        consulted and site occupancy is not checked here. Only each site's
        ``allowed_bond_types`` is tested.
        """
        sites = self._lookup_sites(action, graph)
        if sites is None:
            return False, "Connection site not found"
        source_site, target_site = sites

        if (action.bond_type not in source_site.allowed_bond_types or
                action.bond_type not in target_site.allowed_bond_types):
            return False, f"Bond type {action.bond_type} not allowed for these sites"

        return True, ""

    def _validate_bond_compatibility(self, action: "AssemblyAction", graph: "MolecularGraph") -> Tuple[bool, str]:
        """Check the bond type against the site-type compatibility table.

        The table is consulted in both (source, target) and (target, source)
        order. When both sites are aromatic, the bonds of the
        ``('aromatic', 'aromatic')`` entry are also allowed. If that entry has
        been removed, the fallback is just ``'AROMATIC'``.

        This means aromatic sites of otherwise-unlisted types accept the
        AROMATIC/SINGLE bonds that _would_violate_aromaticity permits, rather
        than AROMATIC alone. If no rule applies at all, any bond type is
        accepted.
        """
        sites = self._lookup_sites(action, graph)
        if sites is None:
            return False, "Connection site not found"
        source_site, target_site = sites

        type_pair = (source_site.site_type, target_site.site_type)
        reverse_pair = (target_site.site_type, source_site.site_type)

        compatible_bonds = set(self.bond_compatibility.get(type_pair, ()))
        compatible_bonds.update(self.bond_compatibility.get(reverse_pair, ()))

        if source_site.is_aromatic and target_site.is_aromatic:
            compatible_bonds.update(
                self.bond_compatibility.get(('aromatic', 'aromatic'), {'AROMATIC'})
            )

        # An empty set means no rule covers this pair: accept any bond type.
        if compatible_bonds and action.bond_type not in compatible_bonds:
            return False, f"Bond type {action.bond_type} not compatible with {type_pair}"

        return True, ""

    def _check_forbidden_patterns(
        self,
        action: "AssemblyAction",
        graph: "MolecularGraph",
        exclude: "Optional[Connection]" = None,
    ) -> Tuple[bool, str]:
        """Run the heuristic ring and aromaticity checks for the proposed bond.

        NOTE(review): despite the name, ``self.forbidden_patterns`` is not
        matched here. Doing so would require building the candidate molecule
        and running substructure matches on it.

        Any exception raised while building the candidate ``Connection`` or
        running the checks is reported as a failed validation instead of
        being propagated.

        Args:
            exclude: an existing connection to ignore in the ring check.
        """
        try:
            candidate = Connection(
                source_motif=action.source_motif,
                source_site=action.source_site,
                target_motif=action.target_motif,
                target_site=action.target_site,
                bond_type=action.bond_type
            )

            if self._would_create_strain(candidate, graph, exclude=exclude):
                return False, "Connection would create ring strain"

            if self._would_violate_aromaticity(candidate, graph):
                return False, "Connection would violate aromaticity rules"

        except Exception as e:
            return False, f"Error validating chemical constraints: {str(e)}"

        return True, ""

    def _would_create_strain(
        self,
        connection: "Connection",
        graph: "MolecularGraph",
        exclude: "Optional[Connection]" = None,
    ) -> bool:
        """Return True if ``connection`` would close a ring of two or three motifs.

        Connections are treated as undirected bonds, so the source/target
        order of existing connections does not matter:

        - a 2-motif ring means the two motifs are already connected;
        - a 3-motif ring means they share a neighbouring motif.

        Args:
            connection: the proposed connection (only its motif ids are read).
            graph: graph whose ``connections`` are inspected.
            exclude: an existing connection to ignore (compared by identity).
                This is used when re-checking a connection already in ``graph``.
        """
        first, second = connection.source_motif, connection.target_motif
        if first == second:
            # NOTE(review): a self-connection does not close a ring through
            # other motifs, so it is not treated as strain here. Confirm
            # whether it should be rejected elsewhere.
            return False

        neighbours = {first: set(), second: set()}
        for existing in graph.connections:
            if existing is exclude:
                continue
            ends = (existing.source_motif, existing.target_motif)
            # Record the bond from both endpoints' point of view.
            for motif, other in (ends, ends[::-1]):
                if motif in neighbours:
                    neighbours[motif].add(other)

        if second in neighbours[first]:
            return True  # already bonded: would form a 2-motif ring
        shared = neighbours[first] & neighbours[second]
        return bool(shared - {first, second})

    def _would_violate_aromaticity(self, connection: "Connection", graph: "MolecularGraph") -> bool:
        """Return True if two aromatic motifs would be joined by a bond other
        than AROMATIC or SINGLE."""
        source_motif = graph.motifs[connection.source_motif]
        target_motif = graph.motifs[connection.target_motif]

        if source_motif.is_aromatic and target_motif.is_aromatic:
            if connection.bond_type not in ['AROMATIC', 'SINGLE']:
                return True

        return False

    def get_constraint_violations(self, graph: "MolecularGraph") -> List[Dict]:
        """Audit every connection already stored in ``graph``.

        Each connection is re-validated as if it were being proposed, with
        two differences:

        1. It is not rejected merely for being an edge of ``graph.graph``.
           Presumably every stored connection is such an edge, so
           validate_action's duplicate check would flag all of them.
        2. It is ignored when looking for rings it would close.

        A later connection between an already-seen pair of motifs, in either
        direction, is reported as a duplicate.

        Returns:
            One dict per failing connection with keys ``'connection'``,
            ``'violation'`` (the failure message) and ``'severity'``.
        """
        violations = []
        seen_pairs = set()

        for connection in graph.connections:
            # Bonds are undirected, so (a, b) and (b, a) are the same pair.
            pair = frozenset((connection.source_motif, connection.target_motif))
            if pair in seen_pairs:
                valid, msg = False, "Connection already exists between these motifs"
            else:
                seen_pairs.add(pair)
                action = AssemblyAction.create_connect_action(
                    source_motif=connection.source_motif,
                    source_site=connection.source_site,
                    target_motif=connection.target_motif,
                    target_site=connection.target_site,
                    bond_type=connection.bond_type
                )
                valid, msg = self._validate(
                    action, graph, check_duplicate=False, exclude=connection
                )

            if not valid:
                violations.append({
                    'connection': connection,
                    'violation': msg,
                    'severity': 'error'
                })

        return violations

    def suggest_corrections(self, violations: List[Dict]) -> List[Dict]:
        """Attach candidate fixes to violations from get_constraint_violations.

        Bond-type violations get a ``'change_bond_type'`` suggestion. This
        covers both the per-site "not allowed" message and the site-type
        "not compatible" message. The suggestion lists alternative bond types
        other than the one the connection already uses. Other violations get
        an empty action list.

        Returns:
            One dict per violation with keys ``'violation'`` (the input dict)
            and ``'suggested_actions'``.
        """
        suggestions = []

        for violation in violations:
            connection = violation['connection']
            msg = violation['violation']
            suggested_actions = []

            if 'Bond type' in msg and ('not allowed' in msg or 'not compatible' in msg):
                # Never suggest the bond type that was just rejected.
                alternatives = [
                    bond for bond in ('SINGLE', 'DOUBLE', 'AROMATIC')
                    if bond != connection.bond_type
                ]
                suggested_actions.append({
                    'type': 'change_bond_type',
                    'alternatives': alternatives
                })

            suggestions.append({
                'violation': violation,
                'suggested_actions': suggested_actions
            })

        return suggestions