"""
Block-specific evaluators for port connectivity analysis.

This module contains evaluators for specific block types that determine
how effectively each block is connected based on its port connectivity patterns.
Different block types have different requirements for earning rewards.

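Key Features:
- Type-specific evaluation rules (e.g. three-phase blocks are only effective
  when both of their sides are fully connected)
- Registry-based design: new block types are supported by registering an
  evaluator with ``BlockConnectivityReward.register_evaluator``
- Ratio-based fallback scoring for block types without a dedicated evaluator
- Aggregation of per-block results into a single 0.0-1.0 system reward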
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Protocol, Set

from system_parser.core.models import UnconnectedPort
try:
    from system_parser.graph.system_graph import SystemGraph
except ImportError:
    # Fallback to legacy system graph
    from system_parser.system_graph import SystemGraph


# Module-level logger named after this module, so logging configuration can
# target it by its dotted module path.
logger = logging.getLogger(__name__)


@dataclass
class BlockConnectivityResult:
    """Result of evaluating how effectively one block's ports are connected.

    ``evaluation_details`` carries evaluator-specific diagnostics. It is
    always a dict after construction, even when ``None`` is passed explicitly.
    """
    block_name: str
    block_type: str
    connectivity_score: float  # 0.0 to 1.0
    is_effective: bool  # whether the block meets its evaluator's requirement
    connected_ports: Set[str]
    unconnected_ports: Set[str]
    total_ports: int
    # Each instance gets its own dict; values are arbitrary (typing.Any).
    evaluation_details: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Callers may still pass None explicitly; normalize it to an empty dict.
        if self.evaluation_details is None:
            self.evaluation_details = {}


class BlockConnectivityEvaluator(Protocol):
    """
    Structural interface for block connectivity evaluators.

    Any object exposing a compatible ``evaluate`` method conforms to this
    protocol; subclassing it is not required.
    """

    def evaluate(self, block_name: str, block_type: str,
                 all_ports: Set[str], unconnected_ports: Set[str]) -> BlockConnectivityResult:
        """
        Score how effectively a single block is connected.

        Args:
            block_name: Name of the block.
            block_type: Type string of the block.
            all_ports: Names of every port on the block.
            unconnected_ports: Names of the block's ports that are unconnected.

        Returns:
            The block's connectivity result.
        """
        ...


class BaseBlockEvaluator(ABC):
    """Abstract base class for block connectivity evaluators.

    Subclasses implement :meth:`evaluate`; the evaluators in this module build
    their return value with :meth:`_create_result`.
    """

    @abstractmethod
    def evaluate(
        self,
        block_name: str,
        block_type: str,
        all_ports: Set[str],
        unconnected_ports: Set[str]
    ) -> BlockConnectivityResult:
        """Evaluate how effectively one block is connected.

        Args:
            block_name: Name of the block being evaluated.
            block_type: Type string of the block (used to select the evaluator).
            all_ports: Names of every port on the block.
            unconnected_ports: Names of the block's ports that are unconnected.

        Returns:
            The evaluation result for the block.
        """

    def _create_result(
        self,
        block_name: str,
        block_type: str,
        all_ports: Set[str],
        unconnected_ports: Set[str],
        connectivity_score: float,
        is_effective: bool,
        evaluation_details: Optional[Dict[str, Any]] = None
    ) -> BlockConnectivityResult:
        """Assemble a :class:`BlockConnectivityResult` from evaluation outputs.

        ``connected_ports`` is derived as ``all_ports - unconnected_ports`` and
        ``total_ports`` as ``len(all_ports)``. A ``None`` or empty
        ``evaluation_details`` is replaced by a fresh empty dict.

        Args:
            block_name: Name of the evaluated block.
            block_type: Type string of the evaluated block.
            all_ports: Names of every port on the block.
            unconnected_ports: Names of the block's unconnected ports.
            connectivity_score: Score computed by the evaluator (0.0 to 1.0).
            is_effective: Whether the block meets the evaluator's requirement.
            evaluation_details: Optional evaluator-specific diagnostics.

        Returns:
            The populated result object.
        """
        return BlockConnectivityResult(
            block_name=block_name,
            block_type=block_type,
            connectivity_score=connectivity_score,
            is_effective=is_effective,
            connected_ports=all_ports - unconnected_ports,
            unconnected_ports=unconnected_ports,
            total_ports=len(all_ports),
            evaluation_details=evaluation_details or {}
        )


class _ThreePhaseTwoSidedEvaluator(BaseBlockEvaluator):
    """
    Shared evaluation logic for three-phase blocks with two terminal sides.

    Requirements for effectiveness:
    - The left side (a1, b1, c1) and the right side (a2, b2, c2) must both be
      fully connected. Only the ports of each side that the block actually
      has are checked; a side with none of its ports present counts as not
      connected.
    - Score is 1.0 when both sides are connected, otherwise 0.0 (no partial
      credit for a single connected side).
    """

    LEFT_SIDE_PORTS = frozenset({'a1', 'b1', 'c1'})
    RIGHT_SIDE_PORTS = frozenset({'a2', 'b2', 'c2'})
    REQUIREMENT = 'Both sides (a1,b1,c1 and a2,b2,c2) must be fully connected'

    @staticmethod
    def _side_status(side_ports, all_ports, unconnected_ports):
        """Return (available, unconnected, connected) for one side of a block."""
        available = side_ports.intersection(all_ports)
        unconnected = unconnected_ports.intersection(available)
        # A side is connected only if it exists on the block and none of its
        # present ports is unconnected.
        connected = len(unconnected) == 0 and len(available) > 0
        return available, unconnected, connected

    def evaluate(
        self,
        block_name: str,
        block_type: str,
        all_ports: Set[str],
        unconnected_ports: Set[str]
    ) -> BlockConnectivityResult:
        """Evaluate a two-sided three-phase block (all-or-nothing on both sides).

        Args:
            block_name: Name of the block.
            block_type: Type string of the block.
            all_ports: Names of every port on the block.
            unconnected_ports: Names of the block's unconnected ports.

        Returns:
            Result with score 1.0 and ``is_effective=True`` only when both
            sides are fully connected; otherwise score 0.0.
        """
        available_left, unconnected_left, left_side_connected = self._side_status(
            self.LEFT_SIDE_PORTS, all_ports, unconnected_ports
        )
        available_right, unconnected_right, right_side_connected = self._side_status(
            self.RIGHT_SIDE_PORTS, all_ports, unconnected_ports
        )

        both_sides_connected = left_side_connected and right_side_connected
        connectivity_score = 1.0 if both_sides_connected else 0.0

        # Port lists are sorted so the details are reproducible across runs
        # (iteration order of string sets varies with hash randomization).
        evaluation_details = {
            'left_side_ports': sorted(available_left),
            'right_side_ports': sorted(available_right),
            'left_side_connected': left_side_connected,
            'right_side_connected': right_side_connected,
            'unconnected_left': sorted(unconnected_left),
            'unconnected_right': sorted(unconnected_right),
            'both_sides_connected': both_sides_connected,
            'requirement': self.REQUIREMENT
        }

        return self._create_result(
            block_name, block_type, all_ports, unconnected_ports,
            connectivity_score, both_sides_connected, evaluation_details
        )


class ThreePhaseVIMeasurementEvaluator(_ThreePhaseTwoSidedEvaluator):
    """
    Evaluator for Three-Phase V-I Measurement blocks.

    Requirements for effectiveness:
    - Both sides must be connected (left side: a1,b1,c1 and right side: a2,b2,c2)
    - Reward is given only when both sides are fully connected
    - No reward if only one side is connected
    """


class ThreePhaseParallelRLCLoadEvaluator(_ThreePhaseTwoSidedEvaluator):
    """
    Evaluator for Three-Phase Parallel RLC Load blocks.

    Requirements for effectiveness:
    - Both sides must be connected (left side: a1,b1,c1 and right side: a2,b2,c2)
    - Reward is given only when both sides are fully connected
    - No reward if only one side is connected
    """


class DefaultBlockEvaluator(BaseBlockEvaluator):
    """
    Default evaluator for blocks without type-specific requirements.

    Ratio-based scoring:
    - Score = connected_ports / total_ports (1.0 for a block with no ports)
    - Effective when the score is at least ``effectiveness_threshold``
      (inclusive; default 0.5, i.e. at least 50% of ports connected)
    """

    def __init__(self, effectiveness_threshold: float = 0.5):
        # Minimum (inclusive) connectivity score for a block to be effective.
        self.effectiveness_threshold = effectiveness_threshold

    def evaluate(
        self,
        block_name: str,
        block_type: str,
        all_ports: Set[str],
        unconnected_ports: Set[str]
    ) -> BlockConnectivityResult:
        """Score a block by the fraction of its ports that are connected.

        Args:
            block_name: Name of the block.
            block_type: Type string of the block.
            all_ports: Names of every port on the block.
            unconnected_ports: Names of the block's unconnected ports.

        Returns:
            Result whose score is the connected-port ratio and which is
            effective when that ratio is >= ``effectiveness_threshold``.
        """
        total_ports = len(all_ports)
        num_connected = len(all_ports - unconnected_ports)

        # A block with no ports has nothing left unconnected, so it scores 1.0.
        connectivity_score = num_connected / total_ports if total_ports else 1.0
        is_effective = connectivity_score >= self.effectiveness_threshold

        evaluation_details = {
            'total_ports': total_ports,
            'connected_ports_count': num_connected,
            'unconnected_ports_count': len(unconnected_ports),
            'effectiveness_threshold': self.effectiveness_threshold,
            'requirement': f'At least {self.effectiveness_threshold:.0%} of ports must be connected'
        }

        return self._create_result(
            block_name, block_type, all_ports, unconnected_ports,
            connectivity_score, is_effective, evaluation_details
        )


class BlockConnectivityReward:
    """
    Main reward system for evaluating block port connectivity.

    Keeps a registry mapping block type strings to evaluators. Block types
    without a registered evaluator fall back to :class:`DefaultBlockEvaluator`.
    Per-block results can then be aggregated into system-level scores.
    """

    # Block types with built-in "both sides connected" evaluators; also the
    # default target set for system-level scoring.
    THREE_PHASE_BLOCK_TYPES = frozenset({
        "Three-Phase V-I Measurement",
        "Three-Phase Parallel RLC Load",
    })

    def __init__(self):
        self.evaluators: Dict[str, BaseBlockEvaluator] = {}
        self.default_evaluator = DefaultBlockEvaluator()

        # Register built-in evaluators
        self._register_builtin_evaluators()

    def _register_builtin_evaluators(self):
        """Register the built-in evaluators for the three-phase block types."""
        self.register_evaluator("Three-Phase V-I Measurement", ThreePhaseVIMeasurementEvaluator())
        self.register_evaluator("Three-Phase Parallel RLC Load", ThreePhaseParallelRLCLoadEvaluator())

    def register_evaluator(self, block_type: str, evaluator: BaseBlockEvaluator):
        """Register (or replace) the evaluator used for ``block_type``."""
        self.evaluators[block_type] = evaluator
        logger.info("Registered evaluator for block type: %s", block_type)

    def get_evaluator(self, block_type: str) -> BaseBlockEvaluator:
        """Return the evaluator registered for ``block_type``, or the default one."""
        return self.evaluators.get(block_type, self.default_evaluator)

    def evaluate_block(
        self,
        block_name: str,
        block_type: str,
        all_ports: Set[str],
        unconnected_ports: Set[str]
    ) -> BlockConnectivityResult:
        """Evaluate a single block with the evaluator selected by its type."""
        evaluator = self.get_evaluator(block_type)
        return evaluator.evaluate(block_name, block_type, all_ports, unconnected_ports)

    def evaluate_system(self, system_graph: SystemGraph) -> Dict[str, BlockConnectivityResult]:
        """
        Evaluate all blocks in a system for connectivity effectiveness.

        Errors are logged (with traceback) rather than raised. If the graph's
        blocks or unconnected ports cannot be read, an empty dict is returned.
        A block whose evaluation raises is skipped, and the remaining blocks
        are still evaluated.

        Args:
            system_graph: The system graph to evaluate

        Returns:
            Dictionary mapping block names to their connectivity results
        """
        results: Dict[str, BlockConnectivityResult] = {}

        try:
            # Unconnected ports grouped by owning block name.
            unconnected_by_block = system_graph.get_unconnected_ports_by_block()
            blocks = list(system_graph.blocks.items())
        except Exception as e:
            logger.exception("Failed to evaluate system blocks: %s", e)
            return results

        for block_name, block in blocks:
            try:
                all_ports = set(block.ports.keys())
                unconnected_ports = {
                    port.port_name
                    for port in unconnected_by_block.get(block_name, [])
                }
                results[block_name] = self.evaluate_block(
                    block_name, block.type, all_ports, unconnected_ports
                )
            except Exception as e:
                # Isolate per-block failures so one bad block does not drop the rest.
                logger.exception("Failed to evaluate block %s: %s", block_name, e)

        return results

    def calculate_system_connectivity_score(self,
                                            results: Dict[str, BlockConnectivityResult],
                                            target_block_types: Optional[Set[str]] = None) -> float:
        """
        Calculate system connectivity score for specific block types.

        Args:
            results: Block evaluation results, keyed by block name.
            target_block_types: Block types to consider (any iterable of type
                strings, or a single type string). If None, defaults to
                {"Three-Phase V-I Measurement", "Three-Phase Parallel RLC Load"}.

        Returns:
            System connectivity score (0.0 to 1.0); 0.0 when no result matches
            the target types.
            - If every target type is a three-phase type: ratio of matching
              blocks that are effective (both sides connected)
            - Otherwise (other or mixed types): mean connectivity score of the
              matching blocks
        """
        if not results:
            return 0.0

        if target_block_types is None:
            targets = self.THREE_PHASE_BLOCK_TYPES
        elif isinstance(target_block_types, str):
            # A bare string is one type name, not a collection of characters.
            targets = frozenset((target_block_types,))
        else:
            targets = frozenset(target_block_types)

        matching = [r for r in results.values() if r.block_type in targets]
        if not matching:
            return 0.0

        if targets <= self.THREE_PHASE_BLOCK_TYPES:
            effective_count = sum(1 for r in matching if r.is_effective)
            return effective_count / len(matching)

        total_score = sum(r.connectivity_score for r in matching)
        return total_score / len(matching)

    def calculate_effective_blocks_ratio(self, results: Dict[str, BlockConnectivityResult]) -> float:
        """
        Calculate the ratio of effective blocks.

        Args:
            results: Block evaluation results

        Returns:
            Ratio of effective blocks (0.0 to 1.0); 0.0 for no results
        """
        if not results:
            return 0.0

        effective_count = sum(1 for result in results.values() if result.is_effective)
        return effective_count / len(results)

    def get_connectivity_summary(self, results: Dict[str, BlockConnectivityResult]) -> Dict[str, Any]:
        """
        Summarize connectivity evaluation results.

        Returns:
            Dict with keys:
            - 'total_blocks', 'effective_blocks', 'effectiveness_ratio':
              counts and ratio over all blocks
            - 'average_score': mean connectivity_score over all blocks
            - 'system_connectivity_score': calculate_system_connectivity_score
              with its default (three-phase) target types
            - 'by_type': per block type 'count', 'effective', 'total_score',
              'average_score' and 'effectiveness_ratio'
        """
        if not results:
            return {
                'total_blocks': 0,
                'effective_blocks': 0,
                'effectiveness_ratio': 0.0,
                'average_score': 0.0,
                'system_connectivity_score': 0.0,
                'by_type': {}
            }

        # Group results by block type
        by_type: Dict[str, Dict[str, Any]] = {}
        for result in results.values():
            stats = by_type.setdefault(
                result.block_type, {'count': 0, 'effective': 0, 'total_score': 0.0}
            )
            stats['count'] += 1
            stats['total_score'] += result.connectivity_score
            if result.is_effective:
                stats['effective'] += 1

        # Calculate averages by type
        for stats in by_type.values():
            stats['average_score'] = stats['total_score'] / stats['count']
            stats['effectiveness_ratio'] = stats['effective'] / stats['count']

        total_blocks = len(results)
        return {
            'total_blocks': total_blocks,
            'effective_blocks': sum(1 for r in results.values() if r.is_effective),
            'effectiveness_ratio': self.calculate_effective_blocks_ratio(results),
            # True mean over all blocks, consistent with the per-type averages;
            # the filtered system score is reported separately below.
            'average_score': sum(r.connectivity_score for r in results.values()) / total_blocks,
            'system_connectivity_score': self.calculate_system_connectivity_score(results),
            'by_type': by_type
        }


# Convenience function for easy integration
def calculate_block_connectivity_reward(system_graph: SystemGraph,
                                        target_block_types: Optional[Set[str]] = None) -> float:
    """
    Score a system's block connectivity as a single reward value.

    Thin wrapper that builds a fresh ``BlockConnectivityReward``, evaluates
    every block of ``system_graph``, and aggregates the per-block results.

    Args:
        system_graph: The system graph to evaluate.
        target_block_types: Block types that contribute to the score. When
            None, the reward system's default is used (the Three-Phase V-I
            Measurement and Three-Phase Parallel RLC Load types).

    Returns:
        Block connectivity reward score (0.0 to 1.0).
    """
    reward_calculator = BlockConnectivityReward()
    per_block_results = reward_calculator.evaluate_system(system_graph)
    score = reward_calculator.calculate_system_connectivity_score(
        per_block_results, target_block_types
    )
    return score
