"""
Unified electrical analysis system for power systems.

This module provides a comprehensive, object-oriented approach to electrical system analysis,
consolidating functionality previously scattered across multiple files.
"""

import logging
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass
from abc import ABC, abstractmethod

logger = logging.getLogger(__name__)


@dataclass
class AnalysisResult:
    """Results from electrical analysis operations.

    A plain record returned by every analyzer's ``analyze`` call and by the
    suite when an analyzer raises.
    """
    # Name of the analyzer that produced the result.
    analyzer_name: str
    # Outcome label. Code in this file also emits "ERROR" when an analysis
    # raises an exception.
    status: str  # "PASS", "FAIL", "WARNING", "INFO"
    # Human-readable description of each problem found; empty when none.
    violations: List[str]
    # Quality score; intended range is 0.0 (worst) to 1.0 (best).
    score: float  # 0.0 to 1.0
    # Analyzer-specific structured data (counts, ratios, per-block values).
    details: Dict[str, Any]
    # One-line textual summary suitable for display.
    summary: str


class BaseElectricalAnalyzer(ABC):
    """Common foundation for every electrical analyzer.

    Concrete analyzers implement ``analyze``; this base only stores the
    analyzer's name and offers a factory for results tagged with that name.
    """

    def __init__(self, name: str):
        self.name = name

    @abstractmethod
    def analyze(self, system_graph) -> AnalysisResult:
        """Run the analysis on ``system_graph`` and return its result."""

    def _create_result(self, status: str, violations: List[str], score: float,
                       details: Dict[str, Any], summary: str) -> AnalysisResult:
        """Build an ``AnalysisResult`` stamped with this analyzer's name.

        Args:
            status: Outcome label for the analysis.
            violations: Descriptions of the problems that were found.
            score: Numeric score for the analysis.
            details: Analyzer-specific structured data.
            summary: One-line textual summary.

        Returns:
            The populated ``AnalysisResult``.
        """
        result_fields = {
            "analyzer_name": self.name,
            "status": status,
            "violations": violations,
            "score": score,
            "details": details,
            "summary": summary,
        }
        return AnalysisResult(**result_fields)


class VoltageCoherenceAnalyzer(BaseElectricalAnalyzer):
    """Analyzes voltage coherence across connected electrical networks.

    Blocks that expose a voltage parameter are grouped into networks by
    following ``connects_to`` edges between ports (possibly through blocks
    that carry no voltage). Every pair of voltage-bearing blocks inside one
    network is then compared against a relative tolerance.
    """

    # Parameter names whose value is a single voltage.
    _SCALAR_VOLTAGE_KEYS = (
        "Phase-to-phase voltage (Vrms)",
        "Nominal phase-to-phase voltage Vn (Vrms)",
    )

    # Transformer winding parameters; the voltage is the first array element.
    # Winding 1 is tried first, winding 2 only if winding 1 yields nothing.
    _WINDING_VOLTAGE_KEYS = (
        "Winding 1 parameters [V1 Ph-Ph(Vrms), R1(pu), L1(pu)]",
        "Winding 2 parameters [V2 Ph-Ph(Vrms), R2(pu), L2(pu)]",
    )

    def __init__(self, threshold: float = 0.2):
        """Create the analyzer.

        Args:
            threshold: Maximum tolerated relative voltage difference between
                two blocks of the same network (0.2 means 20%).
        """
        super().__init__("voltage_coherence")
        self.threshold = threshold

    def analyze(self, system_graph) -> AnalysisResult:
        """Analyze voltage coherence across the electrical network.

        Returns:
            A result whose status is "PASS" when no mismatch was found and
            "FAIL" otherwise. The score is ``1 / (1 + violations)``.
        """
        violations = self._perform_voltage_analysis(system_graph)
        num_violations = len(violations)

        score = max(0.0, min(1.0, 1.0 / (1.0 + num_violations)))
        status = "PASS" if num_violations == 0 else "FAIL"

        return self._create_result(
            status=status,
            violations=violations,
            score=score,
            details={
                "threshold": self.threshold,
                "total_violations": num_violations,
                "voltage_networks": self._analyze_voltage_networks(system_graph)
            },
            summary=f"Voltage coherence: {num_violations} violations detected, score: {score:.3f}"
        )

    def _collect_block_voltages(self, system_graph) -> Dict[str, float]:
        """Map block name to voltage for every block that exposes one."""
        block_voltages = {}
        for block_name, block in system_graph.blocks.items():
            voltage = self._extract_voltage(block)
            if voltage is not None:
                block_voltages[block_name] = voltage
        return block_voltages

    def _perform_voltage_analysis(self, system_graph) -> List[str]:
        """Perform voltage coherence analysis and return violations.

        Returns:
            One message per incompatible pair of blocks; empty when fewer
            than two blocks expose a voltage.
        """
        block_voltages = self._collect_block_voltages(system_graph)
        if len(block_voltages) < 2:
            return []  # Not enough voltage information for comparison

        violations = []
        for network in self._find_voltage_networks(system_graph, block_voltages):
            members = list(network.items())
            # Compare each unordered pair exactly once.
            for index, (first_name, first_voltage) in enumerate(members):
                for second_name, second_voltage in members[index + 1:]:
                    violation = self._check_voltage_compatibility(
                        first_name, second_name, first_voltage, second_voltage
                    )
                    if violation:
                        violations.append(violation)

        return violations

    def _find_voltage_networks(self, system_graph, block_voltages: Dict[str, float]) -> List[Dict[str, float]]:
        """Find separate electrical networks containing voltage-bearing blocks.

        Returns:
            One ``{block_name: voltage}`` dict per network that contains at
            least two voltage-bearing blocks.
        """
        # Build the block-level adjacency once; scanning every edge for each
        # visited block would cost O(blocks * edges).
        adjacency = self._build_block_adjacency(system_graph)

        visited: Set[str] = set()
        networks = []
        for block_name in block_voltages:
            if block_name in visited:
                continue  # Already assigned to a previously explored network

            network = self._explore_electrical_network(
                system_graph, block_name, block_voltages, visited, adjacency
            )
            if len(network) >= 2:  # A lone voltage has nothing to compare against
                networks.append(network)

        return networks

    def _build_block_adjacency(self, system_graph) -> Dict[str, List[str]]:
        """Return block name -> neighbouring block names over ``connects_to`` edges.

        Neighbour lists keep edge iteration order and may contain duplicates
        (parallel edges); the traversal skips blocks it has already seen.
        """
        nodes = system_graph.graph.nodes
        adjacency: Dict[str, List[str]] = {}
        for src_port_id, tgt_port_id, edge_data in system_graph.graph.edges(data=True):
            if edge_data.get("edge_type") != "connects_to":
                continue

            src_block = nodes[src_port_id].get("block_name")
            tgt_block = nodes[tgt_port_id].get("block_name")
            if not (src_block and tgt_block):
                continue  # A port without an owning block links nothing

            adjacency.setdefault(src_block, []).append(tgt_block)
            adjacency.setdefault(tgt_block, []).append(src_block)
        return adjacency

    def _explore_electrical_network(self, system_graph, start_block: str, block_voltages: Dict[str, float],
                                    visited: Set[str],
                                    adjacency: Optional[Dict[str, List[str]]] = None) -> Dict[str, float]:
        """Explore an electrical network starting from a block, finding all voltage-bearing blocks.

        Args:
            system_graph: Graph being analyzed.
            start_block: Name of the block to start the traversal from.
            block_voltages: Known voltages keyed by block name.
            visited: Shared set; every voltage-bearing block reached is added
                so callers do not start another traversal from it.
            adjacency: Optional precomputed neighbour map. When omitted the
                neighbours are looked up per block via ``_get_connected_blocks``.

        Returns:
            ``{block_name: voltage}`` for the voltage-bearing blocks reachable
            from ``start_block``.
        """
        network = {}
        pending = [start_block]
        seen = set()

        while pending:
            current_block = pending.pop()
            if current_block in seen:
                continue
            seen.add(current_block)

            if current_block in block_voltages:
                network[current_block] = block_voltages[current_block]
                visited.add(current_block)  # Mark as visited globally

            if adjacency is not None:
                neighbours = adjacency.get(current_block, ())
            else:
                neighbours = self._get_connected_blocks(system_graph, current_block)

            pending.extend(name for name in neighbours if name not in seen)

        return network

    def _get_connected_blocks(self, system_graph, block_name: str) -> List[str]:
        """Get all blocks directly connected to the given block.

        Scans every ``connects_to`` edge, so prefer ``_build_block_adjacency``
        when querying many blocks.
        """
        nodes = system_graph.graph.nodes
        connected_blocks = set()

        for src_port_id, tgt_port_id, edge_data in system_graph.graph.edges(data=True):
            if edge_data.get("edge_type") != "connects_to":
                continue

            src_block = nodes[src_port_id].get("block_name")
            tgt_block = nodes[tgt_port_id].get("block_name")

            # If one endpoint is our block, the other endpoint is a neighbour.
            if src_block == block_name and tgt_block:
                connected_blocks.add(tgt_block)
            elif tgt_block == block_name and src_block:
                connected_blocks.add(src_block)

        return list(connected_blocks)

    def _check_voltage_compatibility(self, block1: str, block2: str, voltage1: float, voltage2: float) -> Optional[str]:
        """Check if two voltages are compatible within the given threshold.

        The difference is measured relative to the larger magnitude of the two
        voltages, so negative values and a (0, non-zero) pair are handled.

        Returns:
            A violation message, or ``None`` when the voltages are compatible
            (including when both are zero).
        """
        reference = max(abs(voltage1), abs(voltage2))
        if reference == 0:
            return None  # Both voltages are zero: identical, and avoids division by zero

        relative_diff = abs(voltage1 - voltage2) / reference

        # Use different threshold for transformer connections
        # (This would need system_graph context to check block types)
        actual_threshold = self.threshold

        if relative_diff > actual_threshold:
            voltage_diff_percent = relative_diff * 100
            return (f"Voltage mismatch between '{block1}' ({voltage1:.0f} V) and '{block2}' ({voltage2:.0f} V): "
                   f"{voltage_diff_percent:.1f}% difference")

        return None

    def _analyze_voltage_networks(self, system_graph) -> Dict[str, Any]:
        """Analyze voltage networks for detailed reporting.

        Returns:
            Counts and per-block voltage values, or an empty dict when the
            voltages could not be collected (the error is logged).
        """
        try:
            block_voltages = self._collect_block_voltages(system_graph)
            return {
                "blocks_with_voltage": len(block_voltages),
                "voltage_values": block_voltages,
                "unique_voltages": len(set(block_voltages.values())) if block_voltages else 0
            }
        except Exception as e:
            # Reporting details are best-effort; never fail the analysis here.
            logger.warning(f"Error analyzing voltage networks: {e}")
            return {}

    def _extract_voltage(self, block) -> Optional[float]:
        """Extract voltage value from a block's parameters.

        Scalar voltage parameters are tried first, then transformer winding
        parameters, whose value is either a string such as
        ``"[735e3, 0.002, 0.08]"`` or an already parsed list/tuple.

        Returns:
            The voltage as a float, or ``None`` when no usable value exists.
        """
        import ast

        params = block.params

        for key in self._SCALAR_VOLTAGE_KEYS:
            if key in params:
                try:
                    return float(params[key])
                except (ValueError, TypeError):
                    continue

        for key in self._WINDING_VOLTAGE_KEYS:
            if key not in params:
                continue
            raw_value = params[key]
            try:
                # literal_eval only understands strings; pass parsed values through.
                winding = ast.literal_eval(raw_value) if isinstance(raw_value, str) else raw_value
                if isinstance(winding, (list, tuple)) and len(winding) >= 1:
                    return float(winding[0])  # voltage is first element
            except (ValueError, TypeError, SyntaxError):
                continue

        return None


class FrequencyCoherenceAnalyzer(BaseElectricalAnalyzer):
    """Analyzes frequency coherence across all blocks.

    The frequency used by most blocks is treated as the system frequency and
    every block using a different value is reported.
    """

    # Parameter names whose value is a single frequency.
    _SCALAR_FREQUENCY_KEYS = (
        "Frequency (Hz)",
        "Nominal frequency fn (Hz)",
        "Frequency used for rlc specification (Hz)",
    )

    # Transformer parameter; the frequency is the second array element.
    _POWER_FREQUENCY_KEY = "Nominal power and frequency [Pn(VA), fn(Hz)]"

    def __init__(self):
        super().__init__("frequency_coherence")

    def analyze(self, system_graph) -> AnalysisResult:
        """Analyze frequency coherence across all blocks.

        Returns:
            A result whose status is "PASS" when every block agrees and "FAIL"
            otherwise. Each violation costs 0.5 of the score, floored at 0.
        """
        violations = self._perform_frequency_analysis(system_graph)
        num_violations = len(violations)

        score = max(0.0, 1.0 - (num_violations * 0.5))
        status = "PASS" if num_violations == 0 else "FAIL"

        return self._create_result(
            status=status,
            violations=violations,
            score=score,
            details={
                "total_violations": num_violations,
                "frequency_analysis": self._analyze_frequencies(system_graph)
            },
            summary=f"Frequency coherence: {num_violations} violations detected, score: {score:.3f}"
        )

    def _collect_frequencies(self, system_graph) -> Dict[str, float]:
        """Map block name to frequency for every block that exposes one."""
        frequencies = {}
        for block_name, block in system_graph.blocks.items():
            frequency = self._extract_frequency(block)
            if frequency is not None:
                frequencies[block_name] = frequency
        return frequencies

    def _perform_frequency_analysis(self, system_graph) -> List[str]:
        """Perform frequency coherence analysis and return violations.

        Returns:
            One message per block whose frequency differs from the most
            common one; empty when no or only one distinct frequency exists.
        """
        frequencies = self._collect_frequencies(system_graph)
        if not frequencies:
            return []  # No frequency information found

        # Count occurrences in a single pass. The dict keeps first-seen order,
        # so a tie is resolved in favour of the frequency encountered first
        # instead of depending on set iteration order.
        freq_counts: Dict[float, int] = {}
        for freq in frequencies.values():
            freq_counts[freq] = freq_counts.get(freq, 0) + 1

        if len(freq_counts) < 2:
            return []  # Every block agrees

        most_common_freq = max(freq_counts, key=freq_counts.get)

        return [
            f"Block '{block_name}' has frequency {freq} Hz, but system frequency should be {most_common_freq} Hz"
            for block_name, freq in frequencies.items()
            if freq != most_common_freq
        ]

    def _analyze_frequencies(self, system_graph) -> Dict[str, Any]:
        """Analyze frequency values across blocks.

        Returns:
            Counts and per-block frequency values, or an empty dict when the
            frequencies could not be collected (the error is logged).
        """
        try:
            frequencies = self._collect_frequencies(system_graph)
            return {
                "blocks_with_frequency": len(frequencies),
                "frequency_values": frequencies,
                # dict.fromkeys de-duplicates while keeping first-seen order.
                "unique_frequencies": list(dict.fromkeys(frequencies.values()))
            }
        except Exception as e:
            # Reporting details are best-effort; never fail the analysis here.
            logger.warning(f"Error analyzing frequencies: {e}")
            return {}

    def _extract_frequency(self, block) -> Optional[float]:
        """Extract frequency value from a block's parameters.

        Scalar frequency parameters are tried first, then the transformer
        "power and frequency" parameter, whose value is either a string such
        as ``"[250e6, 60]"`` or an already parsed list/tuple.

        Returns:
            The frequency as a float, or ``None`` when no usable value exists.
        """
        import ast

        params = block.params

        for key in self._SCALAR_FREQUENCY_KEYS:
            if key in params:
                try:
                    return float(params[key])
                except (ValueError, TypeError):
                    continue

        if self._POWER_FREQUENCY_KEY in params:
            raw_value = params[self._POWER_FREQUENCY_KEY]
            try:
                # literal_eval only understands strings; pass parsed values through.
                power_freq = ast.literal_eval(raw_value) if isinstance(raw_value, str) else raw_value
                if isinstance(power_freq, (list, tuple)) and len(power_freq) >= 2:
                    return float(power_freq[1])  # frequency is second element
            except (ValueError, TypeError, SyntaxError):
                pass

        return None


class PortConnectivityAnalyzer(BaseElectricalAnalyzer):
    """Analyzes port connectivity effectiveness for blocks."""

    def __init__(self):
        super().__init__("port_connectivity")

    def analyze(self, system_graph) -> AnalysisResult:
        """Evaluate how well each block's ports are connected.

        The per-block evaluation is delegated to ``BlockConnectivityReward``.
        Any exception raised on the way (including a failed import) is logged
        and turned into an "ERROR" result with score 0.0.
        """
        try:
            # Imported lazily, inside the guarded region, so an import failure
            # is reported as an analysis error as well.
            from .block_evaluators import BlockConnectivityReward

            evaluator = BlockConnectivityReward()
            results = evaluator.evaluate_system(system_graph)

            total_blocks = len(results)
            effective_blocks = len([r for r in results.values() if r.is_effective])

            if total_blocks > 0:
                avg_score = sum(r.connectivity_score for r in results.values()) / total_blocks
                effectiveness_ratio = effective_blocks / total_blocks
            else:
                avg_score = 0.0
                effectiveness_ratio = 0.0

            # First threshold that the ratio reaches decides the status.
            status = "FAIL"
            for minimum_ratio, label in ((0.8, "PASS"), (0.5, "WARNING")):
                if effectiveness_ratio >= minimum_ratio:
                    status = label
                    break

            violations = [
                f"Block '{result.block_name}' ({result.block_type}): {len(result.unconnected_ports)} unconnected ports"
                for result in results.values()
                if not result.is_effective
            ]

            per_block = {}
            for name, r in results.items():
                per_block[name] = {
                    "score": r.connectivity_score,
                    "effective": r.is_effective,
                    "unconnected_ports": len(r.unconnected_ports)
                }

            details = {
                "total_blocks": total_blocks,
                "effective_blocks": effective_blocks,
                "effectiveness_ratio": effectiveness_ratio,
                "block_results": per_block,
                "connectivity_summary": evaluator.get_connectivity_summary(results)
            }

            return self._create_result(
                status=status,
                violations=violations,
                score=avg_score,
                details=details,
                summary=f"Port connectivity: {effective_blocks}/{total_blocks} blocks effective, avg score: {avg_score:.3f}"
            )

        except Exception as e:
            logger.error(f"Error in port connectivity analysis: {e}")
            return self._create_result(
                status="ERROR",
                violations=[f"Analysis error: {str(e)}"],
                score=0.0,
                details={"error": str(e)},
                summary=f"Port connectivity analysis failed: {str(e)}"
            )


class ConnectivityAnalyzer(BaseElectricalAnalyzer):
    """Analyzes electrical connectivity between generators and loads."""

    def __init__(self):
        super().__init__("connectivity")

    def analyze(self, system_graph) -> AnalysisResult:
        """Report which loads are reachable, based on the graph's own analysis.

        The work is done by ``system_graph.analyze_connectivity()``; this
        method turns its outcome into a result. The status is "PASS" only when
        no load is isolated, and the score is the reported connectivity ratio.
        """
        outcome = system_graph.analyze_connectivity()

        score = outcome.connectivity_ratio
        isolated_count = len(outcome.isolated_load_names)

        if isolated_count == 0:
            status = "PASS"
        else:
            status = "FAIL"

        violations = []
        for load in outcome.isolated_load_names:
            violations.append(f"Load '{load}' is electrically isolated")

        details = {
            "total_generators": outcome.total_generators,
            "total_loads": outcome.total_loads,
            "connected_loads": outcome.connected_loads,
            "isolated_loads": outcome.isolated_load_names,
            "paths_found": outcome.paths_found,
            "connectivity_ratio": outcome.connectivity_ratio
        }
        summary = f"Connectivity: {outcome.connected_loads}/{outcome.total_loads} loads connected, score: {score:.3f}"

        return self._create_result(
            status=status,
            violations=violations,
            score=score,
            details=details,
            summary=summary
        )


class SystemCompletenessAnalyzer(BaseElectricalAnalyzer):
    """Analyzes system completeness by analyzing isolated blocks.

    A block is isolated when none of its ports takes part in a
    ``connects_to`` edge.
    """

    def __init__(self):
        super().__init__("system_completeness")

    def analyze(self, system_graph) -> AnalysisResult:
        """Analyze system completeness by analyzing isolated blocks.

        Returns:
            "PASS" when no block is isolated, "WARNING" when at most 20% of
            the blocks are isolated, "FAIL" otherwise (or when the system has
            no blocks). The score is the fraction of connected blocks.
        """
        all_blocks = system_graph.blocks
        total_blocks = len(all_blocks)

        if total_blocks == 0:
            return self._create_result(
                status="FAIL",
                violations=["No blocks found in system"],
                score=0.0,
                details={"total_blocks": 0, "isolated_blocks": 0, "connection_ratio": 0.0},
                summary="System completeness: No blocks found, score: 0.000"
            )

        # Collect the connected block names in one pass over the edges rather
        # than rescanning every edge for every block.
        connected_names = self._collect_connected_block_names(system_graph)

        isolated_blocks = [
            block.name for block in all_blocks.values()
            if block.name not in connected_names
        ]
        num_isolated = len(isolated_blocks)
        connected_blocks = total_blocks - num_isolated

        # total_blocks is non-zero here thanks to the early return above.
        connection_ratio = connected_blocks / total_blocks
        score = max(0.0, min(1.0, connection_ratio))

        if num_isolated == 0:
            status = "PASS"
        elif num_isolated <= total_blocks * 0.2:  # At most 20% isolated
            status = "WARNING"
        else:
            status = "FAIL"

        violations = []
        if num_isolated > 0:
            violations.append(f"{num_isolated} blocks are isolated (no electrical connections)")
            # List every name when there are few, otherwise only the first three.
            if num_isolated <= 5:
                violations.extend([f"Isolated block: '{block}'" for block in isolated_blocks])
            else:
                violations.extend([f"Isolated block: '{block}'" for block in isolated_blocks[:3]])
                violations.append(f"... and {num_isolated - 3} more isolated blocks")

        return self._create_result(
            status=status,
            violations=violations,
            score=score,
            details={
                "total_blocks": total_blocks,
                "connected_blocks": connected_blocks,
                "isolated_blocks": num_isolated,
                "isolated_block_names": isolated_blocks,
                "connection_ratio": connection_ratio
            },
            summary=f"System completeness: {connected_blocks}/{total_blocks} blocks connected, score: {score:.3f}"
        )

    def _collect_connected_block_names(self, system_graph) -> set:
        """Return the block names found at either end of a ``connects_to`` edge.

        ``None`` ends up in the set when a port node carries no
        ``block_name``; that mirrors the equality test performed by
        ``_check_block_connections``.
        """
        nodes = system_graph.graph.nodes
        connected_names = set()
        for src_port_id, tgt_port_id, edge_data in system_graph.graph.edges(data=True):
            if edge_data.get("edge_type") != "connects_to":
                continue
            connected_names.add(nodes.get(src_port_id, {}).get("block_name"))
            connected_names.add(nodes.get(tgt_port_id, {}).get("block_name"))
        return connected_names

    def _check_block_connections(self, block, system_graph) -> bool:
        """Check if a block has any electrical connections.

        Scans the edges until the first ``connects_to`` edge touching the
        block. Use ``_collect_connected_block_names`` when testing many blocks.
        """
        nodes = system_graph.graph.nodes
        for src_port_id, tgt_port_id, edge_data in system_graph.graph.edges(data=True):
            if edge_data.get("edge_type") != "connects_to":
                continue

            src_block = nodes.get(src_port_id, {}).get("block_name")
            tgt_block = nodes.get(tgt_port_id, {}).get("block_name")

            if src_block == block.name or tgt_block == block.name:
                return True

        return False


class ParameterAnalyzer(BaseElectricalAnalyzer):
    """Analyzes block parameters."""

    def __init__(self):
        super().__init__("parameter_analysis")

    def analyze(self, system_graph) -> AnalysisResult:
        """Analyze parameter validation issues.

        Reads ``system_graph.param_validation_issues``; a missing or ``None``
        attribute is treated as "no issues". Each issue costs 0.1 of the
        score, floored at 0.

        Returns:
            "PASS" when there are no issues, "WARNING" otherwise. The details
            hold ``issues_by_block`` (one invalid key per block; the last one
            wins when a block has several) and ``invalid_keys_by_block``
            (every invalid key per block, in reported order).
        """
        # `or []` also covers an attribute that exists but is None.
        param_issues = getattr(system_graph, 'param_validation_issues', None) or []
        num_issues = len(param_issues)

        score = max(0.0, 1.0 - (num_issues * 0.1))
        status = "PASS" if num_issues == 0 else "WARNING"
        violations = [f"Block '{issue.block_name}': Invalid key '{issue.invalid_key}'" for issue in param_issues]

        # Lossless grouping: a block may have more than one invalid key.
        invalid_keys_by_block: Dict[str, List[Any]] = {}
        for issue in param_issues:
            invalid_keys_by_block.setdefault(issue.block_name, []).append(issue.invalid_key)

        return self._create_result(
            status=status,
            violations=violations,
            score=score,
            details={
                "total_issues": num_issues,
                # Kept in its original shape for existing consumers.
                "issues_by_block": {issue.block_name: issue.invalid_key for issue in param_issues},
                "invalid_keys_by_block": invalid_keys_by_block
            },
            summary=f"Parameter analysis: {num_issues} issues detected, score: {score:.3f}"
        )


class BasicValidationAnalyzer(BaseElectricalAnalyzer):
    """Analyzes basic graph validation (errors, warnings, unconnected ports)."""

    def __init__(self):
        super().__init__("basic_validation")

    def analyze(self, system_graph) -> AnalysisResult:
        """Summarise the validation findings recorded on the graph.

        Counts the entries of ``validation_errors``, ``validation_warnings``
        and ``unconnected_ports`` (each treated as empty when the attribute is
        absent). Only errors make the status "FAIL"; all three lower the score.
        """
        def count_of(attribute: str) -> int:
            return len(getattr(system_graph, attribute, []))

        errors = count_of('validation_errors')
        warnings = count_of('validation_warnings')
        unconnected = count_of('unconnected_ports')

        # Errors weigh 1.0 each, warnings 0.5, unconnected ports 0.2.
        penalty = errors + 0.5 * warnings + 0.2 * unconnected
        score = max(0.0, min(1.0, 1.0 / (1.0 + penalty)))

        if errors == 0:
            status = "PASS"
        else:
            status = "FAIL"

        findings = (
            (errors, f"{errors} validation errors found"),
            (warnings, f"{warnings} validation warnings found"),
            (unconnected, f"{unconnected} unconnected ports found"),
        )
        violations = [message for count, message in findings if count > 0]

        details = {
            "errors": errors,
            "warnings": warnings,
            "unconnected_ports": unconnected,
            "penalty": penalty
        }

        return self._create_result(
            status=status,
            violations=violations,
            score=score,
            details=details,
            summary=f"Basic validation: {errors} errors, {warnings} warnings, {unconnected} unconnected ports, score: {score:.3f}"
        )


class ElectricalAnalysisSuite:
    """Comprehensive electrical analysis suite.

    Holds a registry of named analyzers, runs any subset of them against a
    system graph and offers helpers to aggregate and format the results.
    """

    def __init__(self, voltage_threshold: float = 0.2):
        """Initialize the analysis suite with optional configuration.

        Args:
            voltage_threshold: Relative tolerance handed to the voltage
                coherence analyzer.
        """
        self.analyzers = {
            "voltage": VoltageCoherenceAnalyzer(voltage_threshold),
            "frequency": FrequencyCoherenceAnalyzer(),
            "connectivity": ConnectivityAnalyzer(),
            "port_connectivity": PortConnectivityAnalyzer(),
            "completeness": SystemCompletenessAnalyzer(),
            "parameters": ParameterAnalyzer(),
            "validation": BasicValidationAnalyzer(),
        }

    def run_analysis(self, system_graph,
                    checks: Optional[List[str]] = None) -> Dict[str, AnalysisResult]:
        """
        Run analysis checks on the system graph.

        An analyzer that raises does not abort the run: the exception is
        logged with its traceback and recorded as an "ERROR" result. Unknown
        check names are logged and skipped.

        Args:
            system_graph: The system graph to analyze
            checks: List of checks to perform. If None, runs all checks.

        Returns:
            Dictionary mapping check names to their results
        """
        if checks is None:
            checks = list(self.analyzers.keys())

        results = {}
        for check_name in checks:
            if check_name not in self.analyzers:
                logger.warning(f"Unknown analysis check: {check_name}")
                continue

            analyzer = self.analyzers[check_name]
            try:
                results[check_name] = analyzer.analyze(system_graph)
            except Exception as e:
                # logger.exception keeps the traceback, which logger.error drops.
                logger.exception("Error running %s analysis: %s", check_name, e)
                results[check_name] = AnalysisResult(
                    # Use the analyzer's own name so error results are labelled
                    # like successful ones; fall back to the registry key.
                    analyzer_name=getattr(analyzer, "name", check_name),
                    status="ERROR",
                    violations=[f"Analysis error: {str(e)}"],
                    score=0.0,
                    details={"error": str(e)},
                    summary=f"{check_name} analysis failed: {str(e)}"
                )

        return results

    def get_overall_score(self, results: Dict[str, AnalysisResult],
                         weights: Optional[Dict[str, float]] = None) -> float:
        """
        Calculate overall score from analysis results.

        Results whose weight is missing, zero or negative are ignored.

        Args:
            results: Dictionary of analysis results
            weights: Optional weights for each analysis type

        Returns:
            Weighted average of the scores, or 0.0 when there are no results
            or no positive weights.
        """
        if not results:
            return 0.0

        if weights is None:
            # Equal weights for all analyses
            weights = {name: 1.0 for name in results}

        total_score = 0.0
        total_weight = 0.0
        for name, result in results.items():
            weight = weights.get(name, 0.0)
            if weight > 0:
                total_score += result.score * weight
                total_weight += weight

        return total_score / total_weight if total_weight > 0 else 0.0

    def format_results(self, results: Dict[str, AnalysisResult]) -> str:
        """Format analysis results for display.

        At most five violations are listed per check, followed by a count of
        the remaining ones. The report ends with the number of checks whose
        status is "PASS" and the unweighted average score.
        """
        if not results:
            return "No analysis results available."

        max_listed = 5
        output_lines = ["=== ELECTRICAL ANALYSIS RESULTS ===", ""]

        for check_name, result in results.items():
            output_lines.append(f"🔍 {check_name.upper()}:")
            output_lines.append(f"   Status: {result.status}")
            output_lines.append(f"   Score: {result.score:.3f}")
            output_lines.append(f"   Summary: {result.summary}")

            if result.violations:
                output_lines.append("   Violations:")
                output_lines.extend(
                    f"     • {violation}" for violation in result.violations[:max_listed]
                )
                if len(result.violations) > max_listed:
                    output_lines.append(f"     ... and {len(result.violations) - max_listed} more")
            else:
                output_lines.append("   ✅ No violations detected")

            output_lines.append("")

        # Overall summary; results is non-empty here, so the division is safe.
        total_checks = len(results)
        passed_checks = sum(1 for r in results.values() if r.status == "PASS")
        avg_score = sum(r.score for r in results.values()) / total_checks

        output_lines.append("=== OVERALL SUMMARY ===")
        output_lines.append(f"Checks passed: {passed_checks}/{total_checks}")
        output_lines.append(f"Average score: {avg_score:.3f}")

        return "\n".join(output_lines)

    def add_analyzer(self, name: str, analyzer: BaseElectricalAnalyzer):
        """Add a custom analyzer to the suite, replacing any with the same name."""
        self.analyzers[name] = analyzer

    def remove_analyzer(self, name: str):
        """Remove an analyzer from the suite; unknown names are ignored."""
        self.analyzers.pop(name, None)

    def get_analyzer_names(self) -> List[str]:
        """Get list of available analyzer names."""
        return list(self.analyzers.keys())