"""
Core DANCE-ST pipeline components.

This module provides a complete implementation of the three-phase DANCE-ST framework,
orchestrated by the main DANCESTPipeline class. The implementation directly follows
the final methodology described in the paper.
"""
import time
import logging
import numpy as np
import networkx as nx
from typing import Callable, Dict, Tuple, Any, List

# Configure logging for the pipeline module.
# NOTE: basicConfig() is executed at import time and targets the *root* logger
# (INFO level, timestamped "<time> - <logger name> - <level> - <message>" format),
# so importing this module affects logging output application-wide.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger; within this file it is used by ConstraintPotentialExtractor
# (Phase I). Phases II and III obtain their own named loggers in their constructors.
logger = logging.getLogger("DANCEST.Pipeline")

class ConstraintPotentialExtractor:
    """
    Phase I: Identifies a critical subgraph using Constraint-Potential Diffusion.

    This class implements the training-free methodology from the paper. It takes the
    current system state, calculates a "constraint potential" based on proximity to
    physical limits, and diffuses this potential across the knowledge graph to find
    the most systemically relevant components.
    """
    def __init__(self, k: int = 50, alpha: float = 0.85, diffusion_iters: int = 10):
        """
        Initializes the extractor.

        Args:
            k (int): Number of top nodes to select for the subgraph.
            alpha (float): Damping factor for the diffusion (balances local vs. propagated influence).
            diffusion_iters (int): Number of iterations for the diffusion process.
        """
        self.k = k
        self.alpha = alpha
        self.diffusion_iters = diffusion_iters

    def _calculate_potential(self, graph: nx.DiGraph, system_state: Dict, constraints: Dict) -> Tuple[List, np.ndarray]:
        """
        Calculates the initial potential vector Φ based on constraint proximity.

        Args:
            graph: Knowledge graph whose nodes define the vector layout.
            system_state: Mapping of node id -> current value; ids that are not
                graph nodes are ignored.
            constraints: Constraint dictionary; only
                ``constraints["temperature_limit"]["value"]`` is read (default 1200.0).

        Returns:
            ``(nodes, potential)`` where ``nodes`` is the node ordering used for every
            vector in this class and ``potential[i]`` belongs to ``nodes[i]``. The
            potential sums to 1 unless no node qualifies, in which case it is all zeros.
        """
        # Sorting gives a deterministic layout; node ids that cannot be compared with
        # each other (e.g. ints mixed with strings) fall back to graph iteration order.
        try:
            sorted_nodes = sorted(graph.nodes())
        except TypeError:
            sorted_nodes = list(graph.nodes())
        node_to_idx = {node: i for i, node in enumerate(sorted_nodes)}
        potential_vector = np.zeros(len(sorted_nodes))

        # Example constraint: a temperature limit
        temp_limit = constraints.get("temperature_limit", {}).get("value", 1200.0)
        epsilon_pot = 1e-6
        # Only nodes within 10% of the failure threshold receive potential.
        proximity_band = temp_limit * 0.1

        for node_id, value in system_state.items():
            idx = node_to_idx.get(node_id)
            if idx is None:
                continue

            slack = temp_limit - value
            # NOTE(review): nodes at or beyond the limit (slack <= 0) receive no
            # potential here - confirm against the paper that this is intended.
            if 0 < slack < proximity_band:
                potential_vector[idx] = 1.0 / (slack + epsilon_pot)

        total = np.sum(potential_vector)
        if total > 0:
            potential_vector /= total  # Normalize for stability
            logger.info(f"Calculated non-zero initial potential (Φ) for {np.count_nonzero(potential_vector)} nodes.")

        return sorted_nodes, potential_vector

    def _run_diffusion(self, graph: nx.DiGraph, potential_vector: np.ndarray, sorted_nodes: List) -> np.ndarray:
        """
        Runs the iterative diffusion process to get final relevance scores Λ.

        Each iteration computes ``Λ <- (1 - alpha) * Φ + alpha * (W @ Λ)`` with
        ``W = D^-1/2 A D^-1/2``, where ``A`` is the adjacency matrix of the undirected
        view of ``graph`` and ``D`` its row sums.

        Args:
            graph: Knowledge graph to diffuse over.
            potential_vector: Initial potential Φ, laid out according to ``sorted_nodes``.
            sorted_nodes: Node ordering matching ``potential_vector``.

        Returns:
            The relevance scores Λ in the same layout. If Φ sums to zero it is
            returned unchanged.
        """
        if np.sum(potential_vector) == 0:
            return potential_vector

        # Use the symmetrically normalized adjacency matrix (W in the paper) for stable diffusion.
        adj_matrix = nx.adjacency_matrix(graph.to_undirected(), nodelist=sorted_nodes)
        degrees = np.asarray(adj_matrix.sum(axis=1)).flatten()
        # The small offset keeps isolated (degree-0) nodes from dividing by zero.
        inv_sqrt_deg = 1.0 / np.sqrt(degrees + 1e-9)

        relevance_scores = potential_vector.copy()
        for _ in range(self.diffusion_iters):
            # W @ Λ evaluated as d * (A @ (d * Λ)): identical maths, but the adjacency
            # matrix stays sparse and no dense n x n matrix is ever materialised.
            propagated = inv_sqrt_deg * np.asarray(adj_matrix @ (inv_sqrt_deg * relevance_scores)).ravel()
            relevance_scores = (1 - self.alpha) * potential_vector + self.alpha * propagated

        logger.info("Completed diffusion process to get final relevance scores (Λ).")
        return relevance_scores

    def extract(self, graph: nx.DiGraph, system_state: Dict, constraints: Dict) -> nx.DiGraph:
        """
        Extracts the most relevant subgraph based on the current system state.

        Args:
            graph: Full knowledge graph.
            system_state: Mapping of node id -> current value.
            constraints: Constraint dictionary (see ``_calculate_potential``).

        Returns:
            A copy of the subgraph induced by the (at most) ``k`` highest-scoring nodes.
            An empty input graph yields ``graph.copy()``; a non-positive ``k`` or
            all-zero relevance scores yield an empty ``nx.DiGraph``.
        """
        if graph.number_of_nodes() == 0:
            return graph.copy()

        # A slice of [-0:] would select *every* node and a negative k would drop the
        # lowest-scored ones, so handle non-positive k explicitly.
        if self.k <= 0:
            logger.warning("Non-positive k requested. Returning empty subgraph.")
            return nx.DiGraph()

        start_time = time.perf_counter()

        # 1. Calculate initial potentials (Φ)
        sorted_nodes, potential_vector = self._calculate_potential(graph, system_state, constraints)

        # 2. Propagate potentials via diffusion to get final scores (Λ)
        final_relevance = self._run_diffusion(graph, potential_vector, sorted_nodes)

        # 3. Select top-k nodes to form the subgraph
        if np.sum(final_relevance) == 0:
            logger.warning("No relevant nodes found after diffusion. Returning empty subgraph.")
            return nx.DiGraph()

        # NOTE(review): when fewer than k nodes have a positive score, the selection is
        # padded with zero-score nodes - confirm whether those should be filtered out.
        top_k_indices = np.argsort(final_relevance)[-self.k:]
        top_nodes = [sorted_nodes[i] for i in top_k_indices]

        subgraph = graph.subgraph(top_nodes).copy()

        elapsed = time.perf_counter() - start_time
        logger.info(f"Phase I completed in {elapsed:.4f}s: Extracted subgraph with {len(subgraph.nodes())} nodes.")
        return subgraph

class UncertaintyWeightedFusion:
    """
    Phase II: Fuses neural and symbolic predictions using uncertainty.

    The fused estimate is ``Ω∗ * f_n + (1 - Ω∗) * f_s``, where the weight on the
    neural prediction grows as the symbolic variance grows relative to the neural one.
    """
    def __init__(self):
        self.logger = logging.getLogger("DANCEST.Phase2")

    @staticmethod
    def _omega_star(sigma_n2: np.ndarray, sigma_s2: np.ndarray) -> np.ndarray:
        """
        Calculates the optimal variance-minimizing weight Ω∗ = σ_s² / (σ_n² + σ_s²).

        Args:
            sigma_n2: Variance of the neural prediction (array-like).
            sigma_s2: Variance of the symbolic prediction (array-like).

        Returns:
            Element-wise weight applied to the neural prediction. Where the total
            variance is below 1e-9 the weight is 0.5 (both sources weighted equally).
        """
        # Coerce to arrays so that plain lists are added element-wise, not concatenated.
        sigma_n2 = np.asarray(sigma_n2)
        sigma_s2 = np.asarray(sigma_s2)
        denominator = sigma_n2 + sigma_s2
        degenerate = denominator < 1e-9
        # np.where evaluates both branches, so dividing by the raw denominator would
        # still perform 0/0 for degenerate entries (warnings, or FloatingPointError
        # under a strict errstate). Divide by a guarded denominator instead.
        safe_denominator = np.where(degenerate, 1.0, denominator)
        return np.where(degenerate, 0.5, sigma_s2 / safe_denominator)

    def fuse(self, f_n: np.ndarray, f_s: np.ndarray, sigma_n2: np.ndarray, sigma_s2: np.ndarray) -> np.ndarray:
        """
        Returns the fused prediction f_int.

        Args:
            f_n: Neural predictions (array-like).
            f_s: Symbolic predictions (array-like).
            sigma_n2: Variances of the neural predictions.
            sigma_s2: Variances of the symbolic predictions.

        Returns:
            ``omega * f_n + (1 - omega) * f_s`` with ``omega`` from ``_omega_star``.
        """
        f_n = np.asarray(f_n)
        f_s = np.asarray(f_s)
        omega = self._omega_star(sigma_n2, sigma_s2)
        f_int = omega * f_n + (1.0 - omega) * f_s
        self.logger.info(f"Phase II completed: Fused predictions with mean neural weight (omega) of {np.mean(omega):.3f}.")
        return f_int

class ConstraintProjection:
    """Phase III: Maps fused predictions into the constraint-satisfying region."""
    def __init__(self, constraints: Dict):
        self.logger = logging.getLogger("DANCEST.Phase3")
        self.constraints = constraints

    def _bound(self, key: str, fallback: float) -> Any:
        """Return ``constraints[key]["value"]``, or ``fallback`` when either level is absent."""
        entry = self.constraints.get(key, {})
        return entry.get("value", fallback)

    def project(self, f_int: np.ndarray) -> np.ndarray:
        """
        Clip ``f_int`` element-wise into the feasible box and return the result.

        The lower edge of the box is read from the ``"rul_lower_bound"`` constraint and
        the upper edge from ``"temperature_limit"``; a missing constraint leaves that
        side unbounded. Per the paper, Douglas-Rachford on box constraints converges in
        a single step to exactly this clipping. The summed absolute change is logged.
        """
        lo = self._bound("rul_lower_bound", -np.inf)
        hi = self._bound("temperature_limit", np.inf)

        f_proj = np.clip(f_int, lo, hi)

        # Total L1 distance moved by the projection, reported for diagnostics only.
        delta = f_proj - f_int
        adjustment = np.sum(np.absolute(delta))
        self.logger.info(f"Phase III completed: Projected predictions. Total adjustment: {adjustment:.4f}.")
        return f_proj

class DANCESTPipeline:
    """
    Top-level driver that chains the three DANCE-ST phases.

    Phase I selects a critical subgraph, Phase II fuses the neural and symbolic
    estimates for its nodes, and Phase III projects the fused estimate onto the
    constraint set.
    """
    def __init__(
        self,
        graph: nx.DiGraph,
        constraints: Dict,
        neural_estimator: Callable[[List[Any]], Tuple[np.ndarray, np.ndarray]],
        symbolic_estimator: Callable[[List[Any]], Tuple[np.ndarray, np.ndarray]],
    ):
        """
        Store the inputs and build one component per phase.

        Args:
            graph: Knowledge graph handed to the Phase I extractor.
            constraints: Constraint dictionary shared by Phase I and Phase III.
            neural_estimator: Callable taking a list of node ids and returning a
                ``(prediction, variance)`` pair.
            symbolic_estimator: Callable with the same contract as ``neural_estimator``.
        """
        # Estimators supplied by the caller.
        self.neural_estimator = neural_estimator
        self.symbolic_estimator = symbolic_estimator

        # Static problem description.
        self.graph = graph
        self.constraints = constraints

        # One component per phase; the extractor uses its default settings.
        self.extractor = ConstraintPotentialExtractor()
        self.fusion = UncertaintyWeightedFusion()
        self.projector = ConstraintProjection(constraints)

    def predict(self, system_state: Dict) -> Dict[str, Any]:
        """
        Execute Phase I -> II -> III for ``system_state``.

        Returns:
            A dict with ``"final_prediction"`` (the projected fused prediction) and
            ``"subgraph_nodes"`` (the node ids it refers to). Both are empty when
            Phase I yields no nodes, in which case the estimators are not called.
        """
        # Phase I: critical subgraph for the current state.
        critical = self.extractor.extract(self.graph, system_state, self.constraints)
        if critical.number_of_nodes() == 0:
            return {"final_prediction": np.array([]), "subgraph_nodes": []}

        node_ids = list(critical.nodes)

        # Phase II: query both estimators (neural first), then fuse by uncertainty.
        neural_mean, neural_var = self.neural_estimator(node_ids)
        symbolic_mean, symbolic_var = self.symbolic_estimator(node_ids)
        fused = self.fusion.fuse(neural_mean, symbolic_mean, neural_var, symbolic_var)

        # Phase III: enforce the constraints on the fused prediction.
        projected = self.projector.project(fused)
        return {"final_prediction": projected, "subgraph_nodes": node_ids}
