"""
DANCE-ST Unified Neural Model Training for CMAPSS Datasets

This script trains a single unified neural model for all CMAPSS turbofan engine datasets:
- FD001: Single operating condition, single failure mode
- FD002: Six operating conditions, single failure mode
- FD003: Single operating condition, two failure modes
- FD004: Six operating conditions, two failure modes

The trained unified model is saved to the DANCEST_model/models/saved directory
(relative to the current working directory).
"""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
import joblib
from datetime import datetime
import logging
import matplotlib.pyplot as plt
from pathlib import Path

# Configure logging once at import time: INFO level, with timestamp, logger
# name and level prepended to every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Module-wide logger used by every function in this script.
logger = logging.getLogger("DANCEST.CMAPSS.UnifiedTraining")

# Put the parent of this file's directory at the front of sys.path.
# NOTE(review): presumably so project-local packages resolve when the script
# is run directly; no project-local import is visible in this file - confirm
# this is still needed.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

def _cmapss_search_dirs(data_dir=None):
    """Return the directories searched for CMAPSS files, in priority order."""
    # A user-supplied directory always takes precedence over the defaults.
    search_dirs = [Path(data_dir)] if data_dir else []
    search_dirs.extend([
        Path("data/@CMAPSS"),
        Path("data/CMAPSS"),
        Path("data"),
        Path("DANCEST_model/data/CMAPSS"),
        Path("DANCEST_model/data/CMAPSSData"),  # Add the actual data location
        Path("DANCEST_model/data"),
        Path("../data/CMAPSS"),
        Path("../data"),
        Path(".")  # Current directory
    ])
    return search_dirs


def _find_cmapss_file(file_name, search_dirs):
    """Return the first path to file_name in search_dirs or their direct subdirectories, or None."""
    for location in search_dirs:
        # is_dir() rather than exists(): a plain file given as a location is
        # skipped instead of making iterdir() raise NotADirectoryError.
        if not location.is_dir():
            continue

        direct_hit = location / file_name
        if direct_hit.exists():
            return direct_hit

        try:
            # Sorted so the result does not depend on filesystem listing order.
            subdirs = sorted(entry for entry in location.iterdir() if entry.is_dir())
        except OSError:
            # Unreadable directory: treat it as "not found here".
            continue

        for subdir in subdirs:
            nested_hit = subdir / file_name
            if nested_hit.exists():
                return nested_hit

    return None


def load_cmapss_dataset(dataset_id, data_dir=None):
    """
    Load one CMAPSS training file and label every row with its RUL.

    The file ``train_<dataset_id>.txt`` is looked up in ``data_dir`` (when
    given) and then in a list of conventional locations; each location is
    checked directly and one subdirectory level down.

    The RUL label is derived from the training file itself: for every row it
    is the last recorded cycle of that row's unit minus the row's cycle, so
    the final row of each unit gets RUL 0. The separate ``RUL_<dataset_id>.txt``
    file is not read by this function and is therefore not required.

    Args:
        dataset_id: Dataset ID (FD001, FD002, FD003, or FD004)
        data_dir: Optional explicit path to data directory

    Returns:
        DataFrame with the columns unit, cycle, op_setting_1..3 and
        sensor_1..21 (minus any column that contains NaN), plus the added
        columns ``dataset_id`` and ``RUL``.

    Raises:
        FileNotFoundError: If the training file is not found in any location.
    """
    logger.info(f"Loading CMAPSS {dataset_id} dataset")

    train_file_name = f"train_{dataset_id}.txt"
    search_dirs = _cmapss_search_dirs(data_dir)
    train_file = _find_cmapss_file(train_file_name, search_dirs)

    if train_file is None:
        checked_paths = "\n  - ".join([str(p) for p in search_dirs])
        error_msg = f"Could not find {train_file_name} in any of these locations:\n  - {checked_paths}"
        logger.error(error_msg)
        raise FileNotFoundError(error_msg)

    logger.info(f"Found dataset files in: {train_file.parent}")

    # Define column names for the CMAPSS dataset
    cols = ['unit', 'cycle', 'op_setting_1', 'op_setting_2', 'op_setting_3'] + \
           [f'sensor_{i}' for i in range(1, 22)]

    # Split on runs of whitespace. With a single-space separator, trailing
    # spaces at the end of each line create extra empty fields; pandas then
    # has more fields than names and shifts the leading columns into the index.
    logger.info(f"Loading training data from: {train_file}")
    train_data = pd.read_csv(train_file, sep=r"\s+", header=None, names=cols)
    train_data = train_data.dropna(axis=1)  # Drop columns containing NaN

    # Add dataset ID as a feature
    train_data['dataset_id'] = dataset_id

    # RUL = last cycle of the unit - current cycle, computed column-wise
    # instead of with a per-row Python callback.
    logger.info("Processing training data")
    last_cycle = train_data.groupby('unit')['cycle'].transform('max')
    train_data['RUL'] = last_cycle - train_data['cycle']

    return train_data

def load_all_cmapss_data(dataset_ids, data_dir=None):
    """
    Load several CMAPSS datasets, build features and split them 70/15/15.

    Each dataset is loaded with ``load_cmapss_dataset``; datasets that fail to
    load are logged and skipped. The features are the operating settings, the
    sensor columns and a one-hot indicator per loaded dataset. The split is
    stratified on (RUL decile, dataset ID). The scaler is fitted on the
    training split only, so validation and test statistics do not leak into
    the normalisation.

    Args:
        dataset_ids: List of dataset IDs to load (duplicates are ignored)
        data_dir: Optional explicit path to data directory

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val, X_test, y_test, scaler,
        feature_cols, combined_data)`` where the ``X_*`` arrays are scaled,
        ``feature_cols`` names the columns of ``X_*`` in order, and
        ``combined_data`` is the concatenated DataFrame including the helper
        columns ``is_<id>``, ``rul_bin`` and ``strat_key``.

    Raises:
        ValueError: If none of the datasets could be loaded.
    """
    # Drop repeated IDs (keeping order) so no dataset is loaded twice and no
    # one-hot column is duplicated in the feature matrix.
    dataset_ids = list(dict.fromkeys(dataset_ids))

    all_data = []

    # Load each dataset; a failing dataset is skipped on purpose (best effort)
    for dataset_id in dataset_ids:
        try:
            dataset = load_cmapss_dataset(dataset_id, data_dir)
        except Exception as e:
            logger.error(f"Error loading {dataset_id}: {e}")
            continue
        all_data.append(dataset)
        logger.info(f"Successfully loaded {dataset_id} with {len(dataset)} samples")

    if not all_data:
        raise ValueError("No datasets were successfully loaded")

    # Combine all datasets
    combined_data = pd.concat(all_data, ignore_index=True)
    logger.info(f"Combined dataset has {len(combined_data)} samples")

    # Create dataset_id one-hot encoding for the datasets that were loaded
    loaded_ids = set(combined_data['dataset_id'].unique())
    dataset_feature_cols = []
    for dataset_id in dataset_ids:
        if dataset_id in loaded_ids:
            column = f'is_{dataset_id}'
            combined_data[column] = (combined_data['dataset_id'] == dataset_id).astype(int)
            dataset_feature_cols.append(column)

    # Base features: operating settings and sensor readings that are present
    base_feature_cols = ['op_setting_1', 'op_setting_2', 'op_setting_3'] + \
                        [f'sensor_{i}' for i in range(1, 22)]
    present_feature_cols = [col for col in base_feature_cols + dataset_feature_cols
                            if col in combined_data.columns]

    # A column dropped from only some datasets becomes NaN for those rows
    # after the concat; such a column cannot be used as a model input.
    valid_feature_cols = [col for col in present_feature_cols
                          if combined_data[col].notna().all()]
    skipped_cols = [col for col in present_feature_cols if col not in valid_feature_cols]
    if skipped_cols:
        logger.warning(f"Skipping feature columns with missing values: {', '.join(skipped_cols)}")

    X = combined_data[valid_feature_cols].values
    y = combined_data['RUL'].values

    # Add stratification based on RUL bins and dataset type
    # This helps ensure balanced representation of data across the splits
    combined_data['rul_bin'] = pd.qcut(combined_data['RUL'], 10, labels=False, duplicates='drop')
    combined_data['strat_key'] = combined_data['rul_bin'].astype(str) + '_' + combined_data['dataset_id']
    stratify_col = combined_data['strat_key']

    # 70% train / 30% temp, stratified. Splitting happens on the unscaled
    # features so the scaler can be fitted on the training part alone.
    X_train_raw, X_temp_raw, y_train, y_temp, _, strat_temp = train_test_split(
        X, y, stratify_col, test_size=0.3, random_state=42, stratify=stratify_col
    )

    # Temp -> validation / test (50/50), stratified when every stratum has at
    # least the two members a stratified split needs.
    temp_stratify = strat_temp
    if strat_temp.value_counts().min() < 2:
        logger.warning("Some strata are too small; validation/test split is not stratified")
        temp_stratify = None

    X_val_raw, X_test_raw, y_val, y_test = train_test_split(
        X_temp_raw, y_temp, test_size=0.5, random_state=42, stratify=temp_stratify
    )

    # Normalize features using training statistics only
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_raw)
    X_val = scaler.transform(X_val_raw)
    X_test = scaler.transform(X_test_raw)

    logger.info(f"Data split: {X_train.shape[0]} training, {X_val.shape[0]} validation, {X_test.shape[0]} test samples")

    return X_train, y_train, X_val, y_val, X_test, y_test, scaler, valid_feature_cols, combined_data

def create_unified_model(input_dim, learning_rate=1e-4):
    """
    Build and compile the fully connected RUL regressor shared by all datasets.

    The network stacks four Dense -> BatchNormalization -> Dropout blocks
    (512, 256, 128 and 64 units), a final Dense(32) -> Dropout block without
    normalisation, and a single linear output unit. It is compiled with Adam,
    MSE loss and MAE as the reported metric.

    Args:
        input_dim: Number of input features
        learning_rate: Learning rate for Adam optimizer

    Returns:
        Compiled Keras model
    """
    # (units, dropout rate) of each batch-normalised hidden block, in order
    normalised_blocks = (
        (512, 0.3),
        (256, 0.3),
        (128, 0.25),
        (64, 0.2),
    )

    network_input = Input(shape=(input_dim,))

    hidden = network_input
    for units, drop_rate in normalised_blocks:
        hidden = Dense(units, activation='relu')(hidden)
        hidden = BatchNormalization()(hidden)
        hidden = Dropout(drop_rate)(hidden)

    # Last hidden block: no batch normalisation
    hidden = Dense(32, activation='relu')(hidden)
    hidden = Dropout(0.1)(hidden)

    # Single linear unit: the predicted RUL
    prediction = Dense(1, activation='linear')(hidden)

    regressor = Model(inputs=network_input, outputs=prediction)
    regressor.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='mse',
        metrics=['mae']
    )
    return regressor

def _plot_kfold_summary(best_history, fold_scores, output_path, show=False):
    """Save a three-panel figure: best-model loss, best-model MAE and validation MAE per fold."""
    fig = plt.figure(figsize=(15, 5))
    try:
        plt.subplot(1, 3, 1)
        plt.plot(best_history.history['loss'])
        plt.plot(best_history.history['val_loss'])
        plt.title('Best Model Loss')
        plt.ylabel('Loss (MSE)')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper right')

        plt.subplot(1, 3, 2)
        plt.plot(best_history.history['mae'])
        plt.plot(best_history.history['val_mae'])
        plt.title('Best Model MAE')
        plt.ylabel('MAE')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper right')

        # Plot validation scores across folds
        fold_numbers = range(1, len(fold_scores) + 1)
        plt.subplot(1, 3, 3)
        plt.bar(fold_numbers, fold_scores)
        plt.title('Validation MAE by Fold')
        plt.ylabel('MAE')
        plt.xlabel('Fold')
        plt.xticks(fold_numbers)

        plt.savefig(output_path)
        if show:
            plt.show()
    finally:
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)


def _evaluate_on_dataset(model, scaler, feature_cols, dataset_id, dataset_ids, data_dir, plot_path):
    """Evaluate model on one CMAPSS training file and save a prediction plot for up to five engines."""
    dataset = load_cmapss_dataset(dataset_id, data_dir)

    # Recreate the one-hot dataset indicators the model was trained with
    for ds_id in dataset_ids:
        if f'is_{ds_id}' in feature_cols:
            dataset[f'is_{ds_id}'] = 1 if ds_id == dataset_id else 0

    X_dataset_scaled = scaler.transform(dataset[feature_cols].values)
    y_dataset = dataset['RUL'].values

    dataset_loss, dataset_mae = model.evaluate(X_dataset_scaled, y_dataset, verbose=0)

    # One prediction pass over the whole file; the per-engine curves below
    # are slices of this array.
    y_pred = model.predict(X_dataset_scaled).flatten()

    unit_ids = dataset['unit'].values
    cycles = dataset['cycle'].values

    # Sample a few engines for detailed visualization
    unique_units = dataset['unit'].unique()
    sample_units = np.random.choice(unique_units, min(5, len(unique_units)), replace=False)

    fig = plt.figure(figsize=(15, 10))
    try:
        for i, unit in enumerate(sample_units):
            unit_rows = unit_ids == unit

            plt.subplot(len(sample_units), 1, i + 1)
            plt.plot(cycles[unit_rows], y_dataset[unit_rows], 'b-', label='Actual RUL')
            plt.plot(cycles[unit_rows], y_pred[unit_rows], 'r-', label='Predicted RUL')
            plt.title(f'Dataset {dataset_id} - Engine Unit {unit}')
            plt.xlabel('Cycle')
            plt.ylabel('RUL')
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.savefig(plot_path)
    finally:
        plt.close(fig)

    return {
        'mae': float(dataset_mae),
        'loss': float(dataset_loss)
    }


def train_unified_model_with_kfold(dataset_ids, n_folds=5, verbose=False, data_dir=None):
    """
    Train a unified neural model using K-fold cross-validation for improved generalization.

    One model is trained per stratified fold of the training split. Every fold
    model is scored on the same hold-out validation split, and the model with
    the lowest hold-out MAE is kept, evaluated on the test split and saved
    together with the scaler and the feature column list. Training curves,
    per-dataset prediction plots and JSON metrics are written to
    ``DANCEST_model/results``; models go to ``DANCEST_model/models/saved``.

    Args:
        dataset_ids: List of dataset IDs to include
        n_folds: Number of folds for cross-validation
        verbose: Whether to enable verbose output (per-batch training progress
            and showing the summary figure)
        data_dir: Optional path to data directory

    Returns:
        best_model, scaler, history (of the best fold), test_mae, results_by_dataset
    """
    import json

    logger.info(f"Training unified model with {n_folds}-fold cross-validation for CMAPSS datasets: {', '.join(dataset_ids)}")

    # Load and preprocess data from all datasets
    X_train_full, y_train_full, X_val, y_val, X_test, y_test, scaler, feature_cols, _ = load_all_cmapss_data(dataset_ids, data_dir)

    # Make sure the output directories exist even when this function is
    # called directly rather than through main().
    plots_dir = Path("DANCEST_model/results")
    plots_dir.mkdir(parents=True, exist_ok=True)
    Path("DANCEST_model/models/saved").mkdir(parents=True, exist_ok=True)

    # Create stratification column (RUL deciles of the training split)
    strat_y = pd.qcut(y_train_full, 10, labels=False, duplicates='drop')

    # Initialize K-fold cross-validation
    kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

    # Initialize model storage
    all_models = []
    all_histories = []
    all_scores = []
    fold_metrics = []

    # Train with k-fold cross-validation
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_train_full, strat_y)):
        logger.info(f"Training fold {fold+1}/{n_folds}")

        # Split data
        X_train_fold = X_train_full[train_idx]
        y_train_fold = y_train_full[train_idx]
        X_val_fold = X_train_full[val_idx]
        y_val_fold = y_train_full[val_idx]

        # Create model for this fold
        input_dim = X_train_fold.shape[1]
        model = create_unified_model(input_dim)

        # Setup callbacks
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_path = f"DANCEST_model/models/saved/cmapss_unified_model_fold{fold+1}_{timestamp}.keras"

        early_stopping = EarlyStopping(
            monitor='val_loss',
            patience=20,
            restore_best_weights=True
        )

        checkpoint = ModelCheckpoint(
            model_path,
            monitor='val_loss',
            save_best_only=True,
            verbose=1
        )

        reduce_lr = ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-6,
            verbose=1
        )

        # Train model; the fold's own validation part drives the callbacks
        logger.info(f"Starting training fold {fold+1} with {X_train_fold.shape[0]} samples")
        history = model.fit(
            X_train_fold, y_train_fold,
            validation_data=(X_val_fold, y_val_fold),
            epochs=100,
            batch_size=64,
            callbacks=[early_stopping, checkpoint, reduce_lr],
            verbose=1 if verbose else 2
        )

        # Score on the shared hold-out validation split (not the fold's own
        # validation part) so that all folds are compared on the same data.
        val_loss, val_mae = model.evaluate(X_val, y_val, verbose=0)
        logger.info(f"Fold {fold+1} validation MAE: {val_mae:.2f}")

        # Save model info
        all_models.append(model)
        all_histories.append(history)
        all_scores.append(val_mae)

        fold_metrics.append({
            'fold': fold + 1,
            'val_mae': float(val_mae),
            'val_loss': float(val_loss),
            'model_path': model_path,
            'n_epochs': len(history.history['loss'])
        })

    # Select best model based on validation MAE
    best_idx = int(np.argmin(all_scores))
    best_model = all_models[best_idx]
    best_history = all_histories[best_idx]

    logger.info(f"Best model from fold {best_idx+1} with validation MAE: {all_scores[best_idx]:.2f}")

    # Evaluate best model on test data
    test_loss, test_mae = best_model.evaluate(X_test, y_test, verbose=0)
    logger.info(f"Best model test MAE: {test_mae:.2f}")

    # Save best model
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    best_model_path = f"DANCEST_model/models/saved/cmapss_unified_model_best_{timestamp}.keras"
    best_model.save(best_model_path)

    # Save scaler
    scaler_path = f"DANCEST_model/models/saved/cmapss_unified_scaler_{timestamp}.joblib"
    joblib.dump(scaler, scaler_path)
    logger.info(f"Best model saved to: {best_model_path}")
    logger.info(f"Scaler saved to: {scaler_path}")

    # Save feature column names
    feature_cols_path = f"DANCEST_model/models/saved/cmapss_unified_features_{timestamp}.json"
    with open(feature_cols_path, 'w') as f:
        json.dump(feature_cols, f)
    logger.info(f"Feature columns saved to: {feature_cols_path}")

    # Plot training history for best model and the per-fold scores
    _plot_kfold_summary(
        best_history,
        all_scores,
        plots_dir / f"cmapss_unified_kfold_training_{timestamp}.png",
        show=verbose
    )

    # Evaluate on each dataset separately. Each dataset's full training file
    # is reloaded here, i.e. the same data the train/validation/test splits
    # were drawn from.
    results_by_dataset = {}
    for dataset_id in dataset_ids:
        try:
            result = _evaluate_on_dataset(
                best_model, scaler, feature_cols, dataset_id, dataset_ids, data_dir,
                plots_dir / f"cmapss_{dataset_id}_predictions_{timestamp}.png"
            )
            results_by_dataset[dataset_id] = result
            logger.info(f"Performance on {dataset_id}: MAE = {result['mae']:.2f}")
        except Exception as e:
            # Best effort: one failing dataset must not discard the others.
            logger.error(f"Error evaluating on {dataset_id}: {e}")
            results_by_dataset[dataset_id] = {'error': str(e)}

    # Save fold metrics
    folds_path = f"DANCEST_model/results/cmapss_unified_folds_{timestamp}.json"
    with open(folds_path, 'w') as f:
        json.dump(fold_metrics, f)

    # Save performance by dataset
    perf_path = f"DANCEST_model/results/cmapss_unified_performance_{timestamp}.json"
    with open(perf_path, 'w') as f:
        json.dump(results_by_dataset, f)
    logger.info(f"Performance by dataset saved to: {perf_path}")

    return best_model, scaler, best_history, test_mae, results_by_dataset

def main():
    """
    Command-line entry point: parse arguments and train the unified model.

    Returns:
        The tuple ``(model, scaler, history, test_mae, results_by_dataset)``
        on success, or ``None`` when the arguments are unusable or training
        fails (the error is logged and printed).
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Train a unified neural model for all CMAPSS datasets')
    parser.add_argument('--datasets', nargs='+', default=['FD001', 'FD002', 'FD003', 'FD004'],
                        help='CMAPSS datasets to include (default: all)')
    parser.add_argument('--folds', type=int, default=5,
                        help='Number of folds for cross-validation (default: 5)')
    parser.add_argument('--verbose', action='store_true', help='Enable verbose output')
    parser.add_argument('--data_dir', type=str, default=None,
                        help='Explicit path to the directory containing CMAPSS data files')
    args = parser.parse_args()

    # Check for valid datasets; tell the user about anything that is dropped
    valid_datasets = ['FD001', 'FD002', 'FD003', 'FD004']
    datasets = [ds for ds in args.datasets if ds in valid_datasets]
    ignored = [ds for ds in args.datasets if ds not in valid_datasets]
    if ignored:
        logger.warning(f"Ignoring unrecognised datasets: {', '.join(ignored)}")

    if not datasets:
        logger.error("No valid datasets specified")
        print("Please specify at least one valid dataset: FD001, FD002, FD003, FD004")
        return

    # Cross-validation needs at least two folds; reject this before any data
    # is loaded instead of failing after the expensive loading step.
    if args.folds < 2:
        logger.error(f"Number of folds must be at least 2, got {args.folds}")
        print("Error: --folds must be at least 2")
        return None

    logger.info(f"Training unified model for datasets: {', '.join(datasets)}")

    # Create necessary directories
    os.makedirs("DANCEST_model/models/saved", exist_ok=True)
    os.makedirs("DANCEST_model/results", exist_ok=True)

    # If data_dir was provided, report whether it exists
    if args.data_dir:
        data_dir = Path(args.data_dir)
        if data_dir.exists():
            logger.info(f"Using specified data directory: {data_dir}")
        else:
            logger.warning(f"Specified data directory {data_dir} does not exist")

    # Print information about where we're looking for data
    print("\nLooking for CMAPSS data files in these locations:")
    if args.data_dir:
        print(f"  - {args.data_dir} (user-specified)")
    print("  - data/@CMAPSS")
    print("  - data/CMAPSS")
    print("  - data")
    print("  - DANCEST_model/data/CMAPSS")
    print("  - DANCEST_model/data/CMAPSSData")
    print("  - And other standard locations\n")

    try:
        # Train unified model with k-fold cross-validation
        model, scaler, history, test_mae, results_by_dataset = train_unified_model_with_kfold(
            datasets, args.folds, args.verbose, args.data_dir
        )
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, then report.
        logger.exception(f"Error training unified model: {e}")
        print(f"Error: {e}")
        return None

    # Print summary
    print("\nTraining Summary:")
    print(f"Unified Model Test MAE: {test_mae:.2f}")
    print("Performance by dataset:")
    for dataset_id, result in results_by_dataset.items():
        if 'error' in result:
            print(f"  {dataset_id}: Failed - {result['error']}")
        else:
            print(f"  {dataset_id}: MAE = {result['mae']:.2f}")

    print("\nUnified model saved to: DANCEST_model/models/saved/cmapss_unified_model_best_*.keras")

    return model, scaler, history, test_mae, results_by_dataset

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()