import torch
from typing import Dict, List, Tuple, Optional, Any, Set
from dataclasses import dataclass
from ...environment.state import AssemblyState
from ...environment.actions import AssemblyAction
from ...environment.constraints import ChemicalConstraints, TopologyConstraints
from ..actor import MotifAgent


@dataclass
class ActionProposal:
    """A candidate action put forward by one motif agent (or the coordinator).

    Instances are created with only the first four fields set; the score
    fields are filled in later by ``CentralizedCoordinator._score_proposals``.
    """
    action: AssemblyAction  # the proposed action (connection or STOP)
    motif_id: str  # proposing agent's motif id; "COORDINATOR" for the built-in STOP proposal
    confidence: float  # proposer-reported confidence
    reasoning: str  # free-text justification; validation failures get appended here
    chemical_score: float = 0.0  # weighted-sum component, set during scoring
    topology_score: float = 0.0  # weighted-sum component, set during scoring
    property_score: float = 0.0  # weighted-sum component, set during scoring
    total_score: float = 0.0  # final score used for ranking


class CentralizedCoordinator:
    """Collect, validate, score and select assembly actions proposed by motif agents.

    Each call to :meth:`coordinate_actions` runs four phases:

    1. gather one proposal per available motif agent, plus a
       coordinator-generated STOP proposal;
    2. drop proposals that fail chemical or topology validation;
    3. score the survivors with a weighted sum of chemical, topology and
       property heuristics (STOP gets its own completion-based score);
    4. return the highest-scoring action.

    Args:
        chemical_weight: Weight of the chemical-stability score.
        topology_weight: Weight of the topology-progress score.
        property_weight: Weight of the property-improvement score.
        consensus_threshold: Minimum total score for one of the top-3
            proposals to be counted in the reported consensus.
    """

    # Base stability score per bond type; unknown bond types score 0.0.
    _BOND_STABILITY: Dict[str, float] = {
        'SINGLE': 0.8,    # single bonds are generally stable
        'DOUBLE': 0.6,    # double bonds are moderately stable
        'AROMATIC': 0.9,  # aromatic bonds are very stable
        'TRIPLE': 0.4,    # triple bonds can be strained
    }

    def __init__(self, chemical_weight: float = 0.3, topology_weight: float = 0.4,
                 property_weight: float = 0.3, consensus_threshold: float = 0.7):
        self.chemical_weight = chemical_weight
        self.topology_weight = topology_weight
        self.property_weight = property_weight
        self.consensus_threshold = consensus_threshold

        self.chemical_constraints = ChemicalConstraints()
        self.topology_constraints = TopologyConstraints()

        # One record per completed coordinate_actions() call; never trimmed.
        self.coordination_history: List[Dict] = []

    def coordinate_actions(self, state: AssemblyState,
                          motif_agents: Dict[str, MotifAgent],
                          temperature: float = 1.0) -> Tuple[AssemblyAction, Dict[str, Any]]:
        """Run one round of proposal collection, validation, scoring and selection.

        Args:
            state: Current assembly state.
            motif_agents: Mapping from motif id to its agent. Agents whose motif
                is not in ``state.available_motifs`` are skipped.
            temperature: Forwarded unchanged to each agent's ``propose_action``.

        Returns:
            ``(action, info)``. In the normal path ``info`` is the coordination
            record that is also appended to ``self.coordination_history``.
        """
        # Phase 1: Collect proposals from all motif agents
        proposals = self._collect_proposals(state, motif_agents, temperature)

        # Phase 2: Filter by chemical and topological validity
        valid_proposals = self._filter_valid_proposals(state, proposals)

        if not valid_proposals:
            # Defensive only: _collect_proposals always adds a STOP proposal and
            # STOP always passes validation, so this branch should not trigger.
            stop_action = AssemblyAction.create_stop_action(
                confidence=1.0,
                reasoning="No valid connection proposals available"
            )
            return stop_action, {
                'phase': 'no_valid_proposals',
                'total_proposals': len(proposals),
                'valid_proposals': 0
            }

        # Phase 3: Score and rank proposals (sorted best-first)
        scored_proposals = self._score_proposals(state, valid_proposals)

        # Phase 4: Select best action
        selected_action, selection_info = self._select_action(state, scored_proposals)

        coordination_record = {
            'step': state.step,
            'total_proposals': len(proposals),
            'valid_proposals': len(valid_proposals),
            'scored_proposals': len(scored_proposals),
            'selected_action': selected_action.to_dict(),
            'selection_info': selection_info,
            'top_scores': [p.total_score for p in scored_proposals[:3]]
        }
        self.coordination_history.append(coordination_record)

        return selected_action, coordination_record

    def _collect_proposals(self, state: AssemblyState,
                          motif_agents: Dict[str, MotifAgent],
                          temperature: float) -> List[ActionProposal]:
        """Ask every available agent for a proposal and append a coordinator STOP proposal.

        An agent that raises is reported on stdout and skipped, so one faulty
        agent cannot abort the whole round.
        """
        proposals = []

        for motif_id, agent in motif_agents.items():
            if motif_id not in state.available_motifs:
                continue

            try:
                action, proposal_info = agent.propose_action(state, temperature)
                proposals.append(ActionProposal(
                    action=action,
                    motif_id=motif_id,
                    confidence=proposal_info.get('confidence', 0.0),
                    reasoning=proposal_info.get('reasoning', '')
                ))
            except Exception as e:
                # Best-effort: log and continue with the remaining agents.
                print(f"Error collecting proposal from motif {motif_id}: {e}")

        # Always include a STOP proposal. Compute confidence and reasoning once
        # and share them between the action and the proposal wrapper.
        stop_confidence = self._calculate_stop_confidence(state)
        stop_reasoning = self._generate_stop_reasoning(state)
        proposals.append(ActionProposal(
            action=AssemblyAction.create_stop_action(
                confidence=stop_confidence,
                reasoning=stop_reasoning
            ),
            motif_id="COORDINATOR",
            confidence=stop_confidence,
            reasoning=stop_reasoning
        ))

        return proposals

    def _filter_valid_proposals(self, state: AssemblyState,
                               proposals: List[ActionProposal]) -> List[ActionProposal]:
        """Return the proposals that pass both chemical and topology validation.

        STOP proposals always pass. Rejected proposals have an ``[INVALID: ...]``
        note appended to their ``reasoning`` in place.
        """
        valid_proposals = []

        for proposal in proposals:
            action = proposal.action

            if action.is_stop_action():
                valid_proposals.append(proposal)
                continue

            chemical_valid, chemical_msg = self.chemical_constraints.validate_action(
                action, state.current_graph
            )
            topology_valid, topology_msg = self.topology_constraints.validate_topology(
                state, action
            )

            if chemical_valid and topology_valid:
                valid_proposals.append(proposal)
                continue

            failure_reason = []
            if not chemical_valid:
                failure_reason.append(f"Chemical: {chemical_msg}")
            if not topology_valid:
                failure_reason.append(f"Topology: {topology_msg}")
            proposal.reasoning += f" [INVALID: {'; '.join(failure_reason)}]"

        return valid_proposals

    def _score_proposals(self, state: AssemblyState,
                        proposals: List[ActionProposal]) -> List[ActionProposal]:
        """Fill in the score fields of each proposal and sort the list in place, best first.

        Returns the same (now sorted) list object.
        """
        for proposal in proposals:
            action = proposal.action

            if action.is_stop_action():
                proposal.total_score = self._score_stop_action(state)
                continue

            proposal.chemical_score = self._score_chemical_stability(state, action)
            proposal.topology_score = self._score_topology_progress(state, action)
            proposal.property_score = self._score_property_improvement(state, action)

            proposal.total_score = (
                self.chemical_weight * proposal.chemical_score +
                self.topology_weight * proposal.topology_score +
                self.property_weight * proposal.property_score
            )

        proposals.sort(key=lambda p: p.total_score, reverse=True)
        return proposals

    def _score_chemical_stability(self, state: AssemblyState, action: AssemblyAction) -> float:
        """Heuristic chemical-compatibility score in ``[0, 1]``.

        The score combines a base score from the bond type, +0.2 if both motifs
        are aromatic, and +0.1 per functional group the motifs share.
        """
        score = self._BOND_STABILITY.get(action.bond_type, 0.0)

        source_motif = state.current_graph.motifs[action.source_motif]
        target_motif = state.current_graph.motifs[action.target_motif]

        if source_motif.is_aromatic and target_motif.is_aromatic:
            score += 0.2

        common_groups = set(source_motif.functional_groups) & set(target_motif.functional_groups)
        score += len(common_groups) * 0.1

        return min(score, 1.0)

    def _score_topology_progress(self, state: AssemblyState, action: AssemblyAction) -> float:
        """Heuristic score for how much ``action`` advances the topology (clamped at 0).

        The score is built as follows:

        - +0.8 when the action joins two different connected components.
        - In reconstruction mode, +0.6 when the edge exists in the target graph
          in either orientation.
        - In reconstruction mode, -0.3 once the current graph already has at
          least as many edges as the target.
        """
        score = 0.0
        current_graph = state.current_graph

        # Locate the component containing each endpoint (None if not found).
        source_component = None
        target_component = None
        for component in current_graph.get_connected_components():
            if action.source_motif in component:
                source_component = component
            if action.target_motif in component:
                target_component = component
            if source_component is not None and target_component is not None:
                break  # components are disjoint; nothing further can match

        if source_component != target_component:
            score += 0.8  # reward for reducing fragmentation

        if state.mode == "reconstruction" and state.target_graph:
            target_edges = self._undirected_edge_set(state.target_graph)
            if frozenset((action.source_motif, action.target_motif)) in target_edges:
                score += 0.6  # bonus for a correct target edge

            # Penalise once the edge budget implied by the target is reached.
            current_edge_count = len(set(current_graph.graph.edges()))
            target_edge_count = len(set(state.target_graph.graph.edges()))
            if current_edge_count >= target_edge_count:
                score -= 0.3

        return max(score, 0.0)

    def _score_property_improvement(self, state: AssemblyState, action: AssemblyAction) -> float:
        """Placeholder property score: 0.5 baseline, +0.2 for 5-20 motifs.

        ``action`` is currently unused; a property-prediction model would go here.
        """
        score = 0.5

        current_size = len(state.current_graph.motifs)
        if 5 <= current_size <= 20:  # reasonable drug-like size
            score += 0.2

        return score

    def _score_stop_action(self, state: AssemblyState) -> float:
        """Score the STOP action by how complete the current assembly looks."""
        if state.mode == "reconstruction" and state.target_graph:
            completion_ratio = self._get_reconstruction_progress(state)
            connectivity_achieved = state.current_graph.num_connected_components() == 1

            if completion_ratio >= 0.9 and connectivity_achieved:
                return 1.0  # near-complete reconstruction
            if completion_ratio >= 0.7:
                return 0.6  # partial completion
            return 0.1  # incomplete reconstruction

        # Generation mode
        connectivity_achieved = state.current_graph.num_connected_components() == 1
        reasonable_size = len(state.current_graph.connections) >= len(state.current_graph.motifs) - 1

        if connectivity_achieved and reasonable_size:
            return 0.8
        return 0.2

    def _select_action(self, state: AssemblyState,
                      scored_proposals: List[ActionProposal]) -> Tuple[AssemblyAction, Dict[str, Any]]:
        """Pick the first (highest-scoring) proposal and describe the selection.

        ``consensus_count`` is the number of top-3 proposals whose total score
        reaches ``consensus_threshold``. It counts high scores, not identical
        actions. When it exceeds 1, a note is appended to the selected
        action's ``reasoning`` in place.
        """
        if not scored_proposals:
            stop_action = AssemblyAction.create_stop_action(
                confidence=1.0,
                reasoning="No proposals to select from"
            )
            return stop_action, {'selection_method': 'fallback_stop'}

        top_proposal = scored_proposals[0]
        consensus_actions = [p for p in scored_proposals[:3] if p.total_score >= self.consensus_threshold]

        selection_info = {
            'selection_method': 'top_score',
            'selected_score': top_proposal.total_score,
            'consensus_count': len(consensus_actions),
            'score_distribution': {
                'chemical': top_proposal.chemical_score,
                'topology': top_proposal.topology_score,
                'property': top_proposal.property_score,
                'total': top_proposal.total_score
            },
            'reasoning': top_proposal.reasoning
        }

        if len(consensus_actions) > 1:
            top_proposal.action.reasoning += f" [CONSENSUS: {len(consensus_actions)} agents agree]"

        return top_proposal.action, selection_info

    def _calculate_stop_confidence(self, state: AssemblyState) -> float:
        """Confidence for the STOP proposal.

        In reconstruction mode this is the progress scaled by 1.2 and capped at
        1.0. Otherwise it is 1.0 when the graph is fully connected and 0.3 when
        it is not.
        """
        if state.mode == "reconstruction" and state.target_graph:
            progress = self._get_reconstruction_progress(state)
            return min(progress * 1.2, 1.0)
        return 1.0 if state.current_graph.num_connected_components() == 1 else 0.3

    def _generate_stop_reasoning(self, state: AssemblyState) -> str:
        """Human-readable justification string for the STOP proposal."""
        reasoning_parts = []

        if state.mode == "reconstruction" and state.target_graph:
            progress = self._get_reconstruction_progress(state)
            reasoning_parts.append(f"Reconstruction progress: {progress:.1%}")

        connectivity = state.current_graph.num_connected_components()
        reasoning_parts.append(f"Connected components: {connectivity}")

        if connectivity == 1:
            reasoning_parts.append("Molecule is fully connected")

        return "STOP: " + "; ".join(reasoning_parts)

    @staticmethod
    def _undirected_edge_set(graph) -> Set[frozenset]:
        """Return ``graph.graph``'s edges as orientation-free frozensets.

        Every comparison against target edges uses this helper, so ``(u, v)``
        and ``(v, u)`` count as the same edge throughout. This matches the
        undirected treatment in ``_score_topology_progress``.
        """
        return {frozenset(edge) for edge in graph.graph.edges()}

    def _get_reconstruction_progress(self, state: AssemblyState) -> float:
        """Fraction of target edges present in the current graph, ignoring edge orientation.

        Returns 0.0 without a target graph and 1.0 when the target has no edges.
        """
        if not state.target_graph:
            return 0.0

        target_edges = self._undirected_edge_set(state.target_graph)
        if not target_edges:
            return 1.0

        current_edges = self._undirected_edge_set(state.current_graph)
        return len(current_edges & target_edges) / len(target_edges)

    def get_coordination_summary(self) -> Dict[str, Any]:
        """Summarise the last 10 coordination records (empty dict if there are none)."""
        if not self.coordination_history:
            return {}

        recent_steps = self.coordination_history[-10:]  # guaranteed non-empty here

        return {
            'total_steps': len(self.coordination_history),
            'recent_proposal_counts': [step['total_proposals'] for step in recent_steps],
            'recent_valid_ratios': [
                step['valid_proposals'] / max(step['total_proposals'], 1)
                for step in recent_steps
            ],
            'recent_top_scores': [step['top_scores'][0] if step['top_scores'] else 0.0 for step in recent_steps],
            'average_proposals_per_step': sum(step['total_proposals'] for step in recent_steps) / len(recent_steps)
        }