import numpy as np
import networkx as nx
import os
import sys
import logging
import argparse
from pathlib import Path
import tensorflow as tf
import joblib
from datetime import datetime
from sklearn.preprocessing import StandardScaler

# Put the directory one level above this script at the front of sys.path so the
# ``DANCEST_model`` package imports below can be resolved.
_parent_dir = os.path.join(os.path.dirname(__file__), os.pardir)
sys.path.insert(0, os.path.abspath(_parent_dir))

# Import DANCE-ST components
from DANCEST_model.Core.pipeline import (
    RelevanceDrivenSubgraphExtractor,
    UncertaintyWeightedFusion,
    CausalConsistencyProjection,
    DANCESTPipeline
)

# Import CMAPSS symbolic model
from DANCEST_model.Core.cmapss_symbolic_model import CmapssSymbolicEstimator

# Configure logging.
# logging.FileHandler opens its file immediately and raises FileNotFoundError
# when the parent directory is missing, so make sure it exists first.
_LOG_FILE = Path("DANCEST_model/agent_workflow.log")
_LOG_FILE.parent.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(_LOG_FILE),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger("DANCEST.Runner")

def load_neural_model(dataset="FD001"):
    """
    Load the trained neural model for the specified CMAPSS dataset.
    
    Args:
        dataset: CMAPSS dataset ID (FD001, FD002, FD003, FD004)
    
    Returns:
        model: Loaded Keras model
        scaler: Loaded scaler for data preprocessing
    """
    # Look in models/saved directory for the newest model matching the dataset
    models_dir = Path("DANCEST_model/models/saved")
    
    # Find all model files for this dataset
    model_files = list(models_dir.glob(f"cmapss_{dataset}_model_*.keras"))
    
    if not model_files:
        logger.warning(f"No trained model found for {dataset}. Using fallback model.")
        # Create a simple model as fallback
        inputs = tf.keras.layers.Input(shape=(20,))  # Assuming 20 features after preprocessing
        x = tf.keras.layers.Dense(64, activation='relu')(inputs)
        x = tf.keras.layers.Dense(32, activation='relu')(x)
        outputs = tf.keras.layers.Dense(1)(x)
        model = tf.keras.Model(inputs=inputs, outputs=outputs)
        model.compile(optimizer='adam', loss='mse')
        
        # Create a dummy scaler
        scaler = StandardScaler()
        
        return model, scaler
    
    # Sort by modification time to get the newest
    newest_model = sorted(model_files, key=os.path.getmtime)[-1]
    logger.info(f"Loading neural model: {newest_model}")
    
    # Load the model
    model = tf.keras.models.load_model(str(newest_model))
    
    # Find the matching scaler
    scaler_name = str(newest_model).replace("model", "scaler").replace(".keras", ".joblib")
    scaler_path = Path(scaler_name)
    
    if not scaler_path.exists():
        logger.warning(f"Scaler not found at {scaler_path}, looking for any scaler for this dataset")
        scaler_files = list(models_dir.glob(f"cmapss_{dataset}_scaler_*.joblib"))
        if scaler_files:
            scaler_path = sorted(scaler_files, key=os.path.getmtime)[-1]
    
    if scaler_path.exists():
        logger.info(f"Loading scaler: {scaler_path}")
        scaler = joblib.load(scaler_path)
    else:
        logger.warning("No scaler found. Using StandardScaler with default parameters.")
        scaler = StandardScaler()
    
    return model, scaler

def create_knowledge_graph(dataset="FD001"):
    """
    Create a knowledge graph for the CMAPSS dataset.

    For CMAPSS, we create a simple graph where:
    - Nodes represent engine components (node attributes: id, name, type)
    - Edges represent causal relationships between components

    FD002/FD004 (multiple operating conditions) add stress/wear edges, and
    FD003/FD004 (multiple failure modes) add HPC-degradation edges. A DiGraph
    stores only one edge per ordered pair. When a dataset-specific edge
    duplicates an existing one (FD003/FD004 add hpc -> combustor), the
    ``relationship`` attribute holds the label written last. Every label is
    kept, in insertion order, in the ``relationships`` list.

    Args:
        dataset: CMAPSS dataset ID (FD001, FD002, FD003, FD004)

    Returns:
        G: NetworkX DiGraph representing the knowledge graph
    """
    G = nx.DiGraph()

    # Define components based on CMAPSS sensors
    components = [
        {"id": "fan", "name": "Fan", "type": "component"},
        {"id": "lpc", "name": "Low Pressure Compressor", "type": "component"},
        {"id": "hpc", "name": "High Pressure Compressor", "type": "component"},
        {"id": "lpt", "name": "Low Pressure Turbine", "type": "component"},
        {"id": "hpt", "name": "High Pressure Turbine", "type": "component"},
        {"id": "combustor", "name": "Combustor", "type": "component"},
        {"id": "nozzle", "name": "Nozzle", "type": "component"},
        {"id": "shaft", "name": "Main Shaft", "type": "component"},
        {"id": "bearing1", "name": "Bearing 1", "type": "component"},
        {"id": "bearing2", "name": "Bearing 2", "type": "component"},
        {"id": "seal1", "name": "Seal 1", "type": "component"},
        {"id": "seal2", "name": "Seal 2", "type": "component"},
    ]
    for component in components:
        G.add_node(component["id"], **component)

    # Base causal relationships shared by all datasets
    edges = [
        ("fan", "lpc", {"relationship": "drives"}),
        ("lpc", "hpc", {"relationship": "drives"}),
        ("hpc", "combustor", {"relationship": "provides_air"}),
        ("combustor", "hpt", {"relationship": "drives"}),
        ("hpt", "lpt", {"relationship": "drives"}),
        ("lpt", "nozzle", {"relationship": "drives"}),
        ("shaft", "fan", {"relationship": "connects"}),
        ("shaft", "lpc", {"relationship": "connects"}),
        ("shaft", "hpc", {"relationship": "connects"}),
        ("shaft", "hpt", {"relationship": "connects"}),
        ("shaft", "lpt", {"relationship": "connects"}),
        ("bearing1", "shaft", {"relationship": "supports"}),
        ("bearing2", "shaft", {"relationship": "supports"}),
        ("seal1", "hpc", {"relationship": "seals"}),
        ("seal2", "hpt", {"relationship": "seals"}),
    ]

    if dataset in ("FD002", "FD004"):  # Multiple operating conditions
        edges.extend([
            ("fan", "bearing1", {"relationship": "stresses"}),
            ("lpt", "bearing2", {"relationship": "stresses"}),
            ("hpc", "seal1", {"relationship": "wears"}),
            ("hpt", "seal2", {"relationship": "wears"}),
        ])

    if dataset in ("FD003", "FD004"):  # Multiple failure modes
        # HPC degradation failure mode
        edges.extend([
            ("hpc", "combustor", {"relationship": "affects_efficiency"}),
            ("hpc", "shaft", {"relationship": "increases_vibration"}),
        ])

    for src, dst, attrs in edges:
        if G.has_edge(src, dst):
            # Previously the second add_edge silently replaced the first label;
            # keep the history while preserving last-write-wins for "relationship".
            edge_data = G[src][dst]
            edge_data["relationships"].append(attrs["relationship"])
            edge_data.update(attrs)
        else:
            G.add_edge(src, dst, relationships=[attrs["relationship"]], **attrs)

    logger.info(f"Created knowledge graph for CMAPSS {dataset} with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
    return G

def create_neural_estimator(model, scaler):
    """
    Create a neural estimator function that can be used by the DANCE-ST pipeline.

    The returned closure builds *synthetic* 20-dimensional feature vectors for
    each vertex. They are deterministic in the vertex ID and ``t``; a real
    system would use actual sensor data. The closure scales the features with
    ``scaler`` and queries ``model`` for RUL predictions and a variance.

    Args:
        model: Trained Keras model. It is called as ``model(x, training=True)``
            (result must expose ``.numpy()``) for Monte Carlo Dropout, with
            ``model.predict(x, verbose=0)`` as the fallback.
        scaler: Fitted scaler for preprocessing (must provide ``transform``).
            If ``transform`` raises, per-column max-abs normalisation is used.

    Returns:
        estimator_fn: Function ``(vertices, t) -> (pred, var)`` returning 1-D
            NumPy arrays with one entry per vertex.
    """
    log = logging.getLogger("DANCEST.Runner")

    n_features = 20      # Synthetic layout below: four groups of five columns.
    n_mc_samples = 10    # Forward passes for Monte Carlo Dropout.
    mc_var_floor = 5.0   # Uncertainty floor to prevent overconfidence.
    fallback_var = 10.0  # Base variance when MC Dropout is unavailable.

    # Per-group degradation slopes, multiplied by t / 100.
    fan_compressor_slope = np.array([0.5, 0.3, 0.2, 0.4, 0.6])
    turbine_combustor_slope = np.array([0.4, 0.6, 0.5, 0.3, 0.2])
    bearing_seal_slope = np.array([0.7, 0.2, 0.5, 0.4, 0.3])
    other_slope = np.array([0.2, 0.3, 0.4, 0.5, 0.3])

    def _vertex_id(vertex):
        """Return the vertex ID: the vertex itself if a string, its 'id' entry if dict-like, else str(vertex)."""
        if isinstance(vertex, str):
            return vertex
        if hasattr(vertex, "get"):
            return str(vertex.get('id', 'unknown'))
        return str(vertex)

    def _synthetic_features(vertex_id, t):
        """Return the deterministic synthetic feature vector for one vertex at time t."""
        # A private RandomState yields exactly the values np.random.seed()+randn()
        # produced, without resetting NumPy's global RNG for the whole process.
        seed = (sum(ord(c) for c in vertex_id) + int(t)) % (2 ** 32)
        raw = np.random.RandomState(seed).randn(n_features) * 0.1

        time_factor = t / 100.0
        # Graph node IDs are abbreviations (lpc/hpc/lpt/hpt); the spelled-out
        # words alone never matched them, so match both.
        if any(key in vertex_id for key in ("fan", "compressor", "lpc", "hpc")):
            # Fan and compressor components degrade faster
            raw[:5] += fan_compressor_slope * time_factor
        elif any(key in vertex_id for key in ("turbine", "combustor", "lpt", "hpt")):
            raw[5:10] += turbine_combustor_slope * time_factor
        elif 'bearing' in vertex_id or 'seal' in vertex_id:
            raw[10:15] += bearing_seal_slope * time_factor
        else:
            raw[15:] += other_slope * time_factor
        return raw

    def neural_estimator(vertices, t):
        """
        Neural estimator function for RUL prediction.

        Args:
            vertices: Input vertices from the knowledge graph (ID strings or
                dict-like objects with an 'id' entry)
            t: Time point (cycle number)

        Returns:
            pred: Predicted RUL values, one per vertex
            var: Variance (uncertainty) in predictions, one per vertex.
                Empty arrays are returned for an empty ``vertices``.
        """
        if len(vertices) == 0:
            # Scaling/normalisation of an empty batch would raise.
            return np.zeros(0), np.zeros(0)

        features = np.stack([_synthetic_features(_vertex_id(v), t) for v in vertices])

        try:
            features_scaled = scaler.transform(features)
        except Exception as exc:
            # e.g. an unfitted scaler: normalise by per-column max magnitude,
            # leaving all-zero columns at zero instead of producing NaN.
            log.debug("Scaler transform failed (%s); using max-abs normalisation", exc)
            denom = np.abs(features).max(axis=0, keepdims=True)
            denom[denom == 0] = 1.0
            features_scaled = features / denom

        try:
            # Monte Carlo Dropout: repeated stochastic forward passes.
            mc_preds = np.array(
                [model(features_scaled, training=True).numpy() for _ in range(n_mc_samples)]
            )
            pred = mc_preds.mean(axis=0).flatten()
            var = np.maximum(mc_preds.var(axis=0).flatten(), mc_var_floor)
        except Exception as exc:
            log.debug("MC Dropout inference failed (%s); falling back to model.predict", exc)
            pred = model.predict(features_scaled, verbose=0).flatten()
            # Later predictions have higher uncertainty.
            var = np.full_like(pred, fallback_var * (1.0 + float(t) / 100.0))

        return pred, var

    return neural_estimator

def create_relevance_functions():
    """
    Create relevance functions for the CMAPSS dataset.

    Each function accepts a vertex given either as an ID string or as a
    dict-like object with an 'id' entry. Compressors and turbines are
    recognised both by the words "compressor"/"turbine" and by the graph's
    abbreviated IDs (lpc/hpc and lpt/hpt).

    Returns:
        causal_fn: Function ``(vertex) -> float`` computing causal relevance
        spatial_fn: Function ``(vertex, spatial_coords) -> float`` computing
            spatial relevance from ``spatial_coords[0]`` (position along the engine)
        temporal_fn: Function ``(vertex, t) -> float`` computing temporal relevance
    """
    # The graph uses abbreviations, so the spelled-out words alone never matched.
    compressor_keys = ("compressor", "lpc", "hpc")
    turbine_keys = ("turbine", "lpt", "hpt")

    def _vertex_id(vertex):
        """Return the vertex ID: the vertex itself if a string, its 'id' entry if dict-like, else str(vertex)."""
        if isinstance(vertex, str):
            return vertex
        if hasattr(vertex, "get"):
            return str(vertex.get('id', 'unknown'))
        return str(vertex)

    def _has_any(vertex_id, keys):
        """Return True if any of ``keys`` occurs as a substring of ``vertex_id``."""
        return any(key in vertex_id for key in keys)

    def causal_relevance(vertex):
        """Compute causal relevance based on vertex type."""
        vertex_id = _vertex_id(vertex)

        # Critical components have higher causal relevance
        if 'hpc' in vertex_id or 'hpt' in vertex_id or 'combustor' in vertex_id:
            return 0.9  # High relevance for core components
        elif 'lpc' in vertex_id or 'lpt' in vertex_id or 'fan' in vertex_id:
            return 0.7  # Medium-high relevance
        elif 'bearing' in vertex_id or 'seal' in vertex_id:
            return 0.6  # Medium relevance
        else:
            return 0.4  # Lower relevance for other components

    def spatial_relevance(vertex, spatial_coords):
        """Compute spatial relevance based on component location along the engine (0 = front, 1 = back)."""
        vertex_id = _vertex_id(vertex)

        # Simplified spatial model based on engine layout.
        # spatial_coords is only indexed for located components, as before.
        if 'fan' in vertex_id:
            return 0.8 if spatial_coords[0] < 0.3 else 0.4  # Fan is in front
        elif _has_any(vertex_id, compressor_keys):
            return 0.8 if 0.2 < spatial_coords[0] < 0.5 else 0.4  # Compressors in middle-front
        elif 'combustor' in vertex_id:
            return 0.8 if 0.4 < spatial_coords[0] < 0.6 else 0.4  # Combustor in middle
        elif _has_any(vertex_id, turbine_keys):
            return 0.8 if 0.5 < spatial_coords[0] < 0.8 else 0.4  # Turbines in middle-back
        elif 'nozzle' in vertex_id:
            return 0.8 if spatial_coords[0] > 0.7 else 0.4  # Nozzle in back
        else:
            return 0.5  # Neutral relevance for other components

    def temporal_relevance(vertex, t):
        """Compute temporal relevance based on time in lifecycle."""
        vertex_id = _vertex_id(vertex)
        t_val = float(t)

        # Early lifecycle: bearings and seals are more relevant
        if t_val < 50:
            return 0.8 if ('bearing' in vertex_id or 'seal' in vertex_id) else 0.5
        # Mid lifecycle: compressors and turbines become more relevant
        if t_val < 150:
            if _has_any(vertex_id, compressor_keys) or _has_any(vertex_id, turbine_keys):
                return 0.9
            return 0.6
        # Late lifecycle: combustor and HPT most critical
        if 'combustor' in vertex_id or 'hpt' in vertex_id:
            return 0.95
        if 'hpc' in vertex_id or 'lpt' in vertex_id:
            return 0.85
        return 0.7

    return causal_relevance, spatial_relevance, temporal_relevance

def run_dancest_pipeline(dataset="FD001", start_t=0, end_t=250, dt=50, k=50, alpha=0.4, beta=0.3, gamma=0.3):
    """
    Run the DANCEST pipeline on CMAPSS data.
    
    Args:
        dataset: CMAPSS dataset ID (FD001, FD002, FD003, FD004)
        start_t: Starting time point (cycle)
        end_t: Ending time point (cycle)
        dt: Time step between predictions
        k: Number of vertices to include in subgraph
        alpha: Weight for causal relevance
        beta: Weight for spatial relevance
        gamma: Weight for temporal relevance
    
    Returns:
        results: Dictionary of predictions at each time point
    """
    logger.info(f"Starting DANCEST pipeline for CMAPSS {dataset}")
    
    # Load neural model
    model, scaler = load_neural_model(dataset)
    
    # Create knowledge graph
    graph = create_knowledge_graph(dataset)
    
    # Create symbolic estimator
    symbolic_estimator = CmapssSymbolicEstimator(dataset)
    
    # Create neural estimator function
    neural_estimator_fn = create_neural_estimator(model, scaler)
    
    # Create relevance functions
    causal_fn, spatial_fn, temporal_fn = create_relevance_functions()
    
    # Create DANCEST components
    extractor = RelevanceDrivenSubgraphExtractor(
        k=k,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        causal_fn=causal_fn,
        spatial_fn=spatial_fn,
        temporal_fn=temporal_fn,
        auto_optimize_weights=True
    )
    
    fusion = UncertaintyWeightedFusion(tau=0.1)
    
    projector = CausalConsistencyProjection(
        lower_bound=0.0,  # RUL cannot be negative
        upper_bound=130.0,  # Maximum RUL value
        eta=0.5,
        max_iter=10
    )
    
    # Create DANCEST pipeline
    pipeline = DANCESTPipeline(
        graph=graph,
        neural_estimator=neural_estimator_fn,
        symbolic_estimator=symbolic_estimator.predict,
        extractor=extractor,
        fusion=fusion,
        projector=projector
    )
    
    # Create spatial points for prediction
    # For demonstration, we create a linear grid along the engine
    n_spatial_points = k  # Set to match k (number of vertices in subgraph)
    spatial_coords = np.linspace(0, 1, n_spatial_points).reshape(-1, 1)
    
    # Create time points for prediction
    time_points = np.arange(start_t, end_t + dt, dt)
    
    # Run pipeline
    logger.info(f"Running predictions for time points: {time_points}")
    predictions = pipeline.predict(spatial_coords, time_points)
    
    # Create results directory
    results_dir = Path("DANCEST_model/results")
    results_dir.mkdir(parents=True, exist_ok=True)
    
    # Save results to file
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    results_file = results_dir / f"cmapss_{dataset}_predictions_{timestamp}.npz"
    
    np.savez(
        results_file,
        spatial_coords=spatial_coords,
        time_points=time_points,
        predictions={str(t): p for t, p in predictions.items()}
    )
    
    logger.info(f"Results saved to {results_file}")
    
    # Print summary of results
    logger.info("Prediction Summary:")
    for t, preds in predictions.items():
        logger.info(f"  Time {t}: mean RUL = {preds.mean():.2f}, std = {preds.std():.2f}")
    
    return predictions

if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run DANCEST pipeline on CMAPSS dataset')
    parser.add_argument('--dataset', type=str, default='FD001', choices=['FD001', 'FD002', 'FD003', 'FD004'],
                      help='CMAPSS dataset to use (default: FD001)')
    parser.add_argument('--start_t', type=int, default=0, help='Starting time point (cycle)')
    parser.add_argument('--end_t', type=int, default=250, help='Ending time point (cycle)')
    parser.add_argument('--dt', type=int, default=50, help='Time step between predictions')
    parser.add_argument('--k', type=int, default=50, help='Number of vertices in subgraph')
    parser.add_argument('--alpha', type=float, default=0.4, help='Weight for causal relevance')
    parser.add_argument('--beta', type=float, default=0.3, help='Weight for spatial relevance')
    parser.add_argument('--gamma', type=float, default=0.3, help='Weight for temporal relevance')

    args = parser.parse_args()

    # Validate weights sum to 1
    if not np.isclose(args.alpha + args.beta + args.gamma, 1.0):
        parser.error("Weights (alpha, beta, gamma) must sum to 1.0")

    # Reject grids the pipeline cannot use. np.arange raises on a zero step,
    # and a reversed range or k < 1 gives an empty (or invalid) grid.
    if args.dt <= 0:
        parser.error("--dt must be a positive integer")
    if args.end_t < args.start_t:
        parser.error("--end_t must be greater than or equal to --start_t")
    if args.k < 1:
        parser.error("--k must be at least 1")

    # Run DANCEST pipeline
    run_dancest_pipeline(
        dataset=args.dataset,
        start_t=args.start_t,
        end_t=args.end_t,
        dt=args.dt,
        k=args.k,
        alpha=args.alpha,
        beta=args.beta,
        gamma=args.gamma
    )