"""
DANCE-ST Neural Model Training for CMAPSS Datasets

This script trains neural models for the CMAPSS turbofan engine datasets:
- FD001: Single operating condition, single failure mode
- FD002: Six operating conditions, single failure mode
- FD003: Single operating condition, two failure modes
- FD004: Six operating conditions, two failure modes

The trained models and fitted scalers are saved to the DANCEST_model/models/saved directory.
"""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import joblib
from datetime import datetime
import logging
import matplotlib.pyplot as plt
from pathlib import Path

# Root logging setup: INFO level, records carry timestamp, logger name and level
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-wide logger used by every function in this script
logger = logging.getLogger("DANCEST.CMAPSS.Training")

# Put the parent of this script's directory at the front of the import path
# so imports resolve regardless of the directory the script is launched from
_PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, _PARENT_DIR)

def download_cmapss_data():
    """
    Download the CMAPSS dataset if it doesn't already exist.

    The download is best-effort: any failure (missing ``requests`` package,
    network error, non-200 response, invalid archive) is logged together with
    a manual-download hint and is NOT raised to the caller.

    Returns:
        Path to the dataset directory (created if missing). The directory is
        returned even when the download failed, so callers must still check
        that the files they need are present.
    """
    # Create dataset directory - use the directory specified by the user
    dataset_dir = Path("data/@CMAPSS")
    dataset_dir.mkdir(parents=True, exist_ok=True)

    # Skip the download when every train/RUL file is already present
    required_files = [
        f"{prefix}_{subset}.txt"
        for prefix in ("train", "RUL")
        for subset in ("FD001", "FD002", "FD003", "FD004")
    ]
    if all((dataset_dir / file).exists() for file in required_files):
        logger.info("CMAPSS datasets already downloaded")
        return dataset_dir

    from io import BytesIO
    import zipfile

    logger.info("Downloading CMAPSS dataset...")
    url = "https://ti.arc.nasa.gov/c/6/"  # NASA C-MAPSS data URL
    manual_hint = "Please download the dataset manually from https://ti.arc.nasa.gov/tech/dash/groups/pcoe/prognostic-data-repository/"

    try:
        # Imported inside the try block so that a missing optional dependency
        # degrades to the manual-download hint like any other failure
        import requests

        # Without a timeout requests may block forever on a stalled server
        response = requests.get(url, timeout=60)
        if response.status_code == 200:
            # Context manager guarantees the archive handle is released
            with zipfile.ZipFile(BytesIO(response.content)) as archive:
                archive.extractall(dataset_dir)
            logger.info(f"Dataset downloaded to {dataset_dir}")
        else:
            logger.error(f"Failed to download dataset: {response.status_code}")
            logger.warning(manual_hint)
    except Exception as e:
        # Deliberately broad: this helper must never abort the caller
        logger.error(f"Error downloading dataset: {e}")
        logger.warning(manual_hint)

    return dataset_dir

def _find_cmapss_train_file(train_file_name, locations):
    """
    Locate a CMAPSS training file among candidate directories.

    Each candidate is checked directly first, then its immediate
    subdirectories (visited in sorted order so the result is deterministic).

    Args:
        train_file_name: File name to look for (e.g. "train_FD001.txt")
        locations: Iterable of Path objects to search, in priority order

    Returns:
        Path to the first matching file, or None if no candidate contains it
    """
    for location in locations:
        # is_dir() rather than exists(): a candidate that is a regular file
        # must be skipped, otherwise iterdir() below would raise
        if not location.is_dir():
            continue

        direct = location / train_file_name
        if direct.exists():
            return direct

        for subdir in sorted(location.iterdir()):
            nested = subdir / train_file_name
            if subdir.is_dir() and nested.exists():
                return nested

    return None


def load_cmapss_data(dataset_id, data_dir=None):
    """
    Load and preprocess CMAPSS data for the specified dataset ID.

    The training file is searched for in ``data_dir`` (if given) followed by a
    list of standard locations. RUL targets are derived per row as
    ``max cycle of the unit - current cycle``. Features are standardized and
    the rows are split randomly 80/10/10 into train/validation/test.

    Args:
        dataset_id: Dataset ID (FD001, FD002, FD003, or FD004)
        data_dir: Optional explicit path to data directory

    Returns:
        X_train, y_train, X_val, y_val, X_test, y_test, scaler, feature_cols
        where ``scaler`` is the fitted StandardScaler and ``feature_cols`` is
        the list of column names that make up the feature matrix.

    Raises:
        FileNotFoundError: If the training file or the matching RUL file
            cannot be found.
    """
    logger.info(f"Loading CMAPSS {dataset_id} dataset")

    # Candidate directories, user-specified one first
    possible_locations = [Path(data_dir)] if data_dir else []
    possible_locations.extend([
        Path("data/@CMAPSS"),
        Path("data/CMAPSS"),
        Path("data"),
        Path("DANCEST_model/data/CMAPSS"),
        Path("DANCEST_model/data"),
        Path("../data/CMAPSS"),
        Path("../data"),
        Path(".")  # Current directory
    ])

    train_file_name = f"train_{dataset_id}.txt"
    train_file = _find_cmapss_train_file(train_file_name, possible_locations)

    if train_file is None:
        checked_paths = "\n  - ".join([str(p) for p in possible_locations])
        error_msg = f"Could not find {train_file_name} in any of these locations:\n  - {checked_paths}"
        logger.error(error_msg)
        raise FileNotFoundError(error_msg)

    dataset_dir = train_file.parent
    logger.info(f"Found dataset files in: {dataset_dir}")

    # Define column names for the CMAPSS dataset
    cols = ['unit', 'cycle', 'op_setting_1', 'op_setting_2', 'op_setting_3'] + \
           [f'sensor_{i}' for i in range(1, 22)]

    # Load training data.
    # sep=r"\s+" collapses runs of whitespace and ignores trailing padding.
    # With sep=" ", trailing spaces on a line produce extra empty fields, and
    # when a line has more fields than `names`, pandas uses the leading
    # columns as an implicit index - silently shifting every named column.
    # index_col=False additionally forbids that implicit index.
    logger.info(f"Loading training data from: {train_file}")
    train_data = pd.read_csv(
        train_file, sep=r"\s+", header=None, names=cols, index_col=False
    )
    train_data = train_data.dropna(axis=1)  # Drop columns containing NaN

    # Load RUL data
    rul_file = dataset_dir / f"RUL_{dataset_id}.txt"
    if not rul_file.exists():
        logger.error(f"RUL file {rul_file} not found")
        raise FileNotFoundError(f"RUL file {rul_file} not found")

    # NOTE(review): these labels are only parsed, never used below (training
    # targets are derived from the cycle counts) - confirm whether the file
    # should really be mandatory here.
    logger.info(f"Loading RUL data from: {rul_file}")
    rul_data = pd.read_csv(
        rul_file, sep=r"\s+", header=None, names=['RUL'], index_col=False
    )
    logger.debug(f"Parsed {len(rul_data)} RUL labels (not used for training)")

    # Process training data
    logger.info("Processing training data")

    # RUL for each row = last observed cycle of its unit minus current cycle.
    # Vectorized via groupby/transform instead of a Python-level row apply;
    # cast to float to keep the target dtype the row-wise version produced.
    last_cycle = train_data.groupby('unit')['cycle'].transform('max')
    train_data['RUL'] = (last_cycle - train_data['cycle']).astype(float)

    # Features: operating settings and sensor readings that survived dropna
    candidate_cols = ['op_setting_1', 'op_setting_2', 'op_setting_3'] + \
                     [f'sensor_{i}' for i in range(1, 22)]
    valid_feature_cols = [col for col in candidate_cols if col in train_data.columns]

    X = train_data[valid_feature_cols].values
    y = train_data['RUL'].values

    # Normalize features
    # NOTE(review): the scaler is fitted on all rows before splitting and the
    # split below is row-wise (rows of one engine land in several splits), so
    # validation/test scores are optimistic. Kept as-is to preserve behavior.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # Split into train (80%) and hold-out (20%) sets
    X_train, X_val, y_train, y_val = train_test_split(
        X_scaled, y, test_size=0.2, random_state=42
    )

    # Halve the hold-out set into validation and test sets
    X_val, X_test, y_val, y_test = train_test_split(
        X_val, y_val, test_size=0.5, random_state=42
    )

    logger.info(f"Data loaded and processed: {X_train.shape[0]} training samples, "
                f"{X_val.shape[0]} validation samples, {X_test.shape[0]} test samples")

    return X_train, y_train, X_val, y_val, X_test, y_test, scaler, valid_feature_cols

def create_model(input_dim, learning_rate=1e-4):
    """
    Build and compile the feed-forward network used for RUL regression.

    The network stacks three ReLU Dense stages (128, 64 and 32 units), each
    followed by dropout (0.3, 0.2 and 0.1); the first two stages also apply
    batch normalization before the dropout. A single linear unit produces
    the RUL estimate.

    Args:
        input_dim: Number of input features
        learning_rate: Learning rate for Adam optimizer

    Returns:
        Compiled Keras model (MSE loss, MAE metric)
    """
    # One entry per hidden stage: (units, apply batch norm?, dropout rate)
    hidden_stages = (
        (128, True, 0.3),
        (64, True, 0.2),
        (32, False, 0.1),
    )

    inputs = Input(shape=(input_dim,))

    hidden = inputs
    for units, use_batch_norm, dropout_rate in hidden_stages:
        hidden = Dense(units, activation='relu')(hidden)
        if use_batch_norm:
            hidden = BatchNormalization()(hidden)
        hidden = Dropout(dropout_rate)(hidden)

    # Regression head: one linear unit for the RUL prediction
    outputs = Dense(1, activation='linear')(hidden)

    network = Model(inputs=inputs, outputs=outputs)
    optimizer = Adam(learning_rate=learning_rate)
    network.compile(optimizer=optimizer, loss='mse', metrics=['mae'])

    return network

def train_model(dataset_id, verbose=False, data_dir=None):
    """
    Train a neural model for the specified CMAPSS dataset.

    Side effects: writes the best model checkpoint and the fitted scaler to
    DANCEST_model/models/saved and a training-history plot to
    DANCEST_model/results (both directories are created if missing).

    Args:
        dataset_id: Dataset ID (FD001, FD002, FD003, or FD004)
        verbose: Whether to enable verbose output (per-batch progress and an
            interactive plot window)
        data_dir: Optional path to data directory

    Returns:
        model, scaler, history, test_mae
    """
    logger.info(f"Training neural model for CMAPSS {dataset_id}")

    # Load and preprocess data
    X_train, y_train, X_val, y_val, X_test, y_test, scaler, _feature_cols = load_cmapss_data(dataset_id, data_dir)

    # Create model
    input_dim = X_train.shape[1]
    model = create_model(input_dim)

    # Make sure the output directory exists here as well, so the function
    # works when it is called without going through main()
    Path("DANCEST_model/models/saved").mkdir(parents=True, exist_ok=True)

    # Setup callbacks
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    model_path = f"DANCEST_model/models/saved/cmapss_{dataset_id}_model_{timestamp}.keras"

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=12,
        restore_best_weights=True
    )

    checkpoint = ModelCheckpoint(
        model_path,
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    )

    # Train model
    logger.info(f"Starting training with {X_train.shape[0]} samples")
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=50,
        batch_size=64,
        callbacks=[early_stopping, checkpoint],
        verbose=1 if verbose else 2
    )

    # Evaluate model on test data
    test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
    logger.info(f"Test MAE: {test_mae:.2f}")

    # Save scaler
    scaler_path = f"DANCEST_model/models/saved/cmapss_{dataset_id}_scaler_{timestamp}.joblib"
    joblib.dump(scaler, scaler_path)
    logger.info(f"Model saved to: {model_path}")
    logger.info(f"Scaler saved to: {scaler_path}")

    # Plot training history: one panel per tracked quantity
    panels = (
        ('loss', f'CMAPSS {dataset_id} Model Loss', 'Loss (MSE)'),
        ('mae', f'CMAPSS {dataset_id} Model MAE', 'MAE'),
    )
    fig = plt.figure(figsize=(10, 5))
    try:
        for position, (metric, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(history.history[metric])
            plt.plot(history.history[f'val_{metric}'])
            plt.title(title)
            plt.ylabel(ylabel)
            plt.xlabel('Epoch')
            plt.legend(['Train', 'Validation'], loc='upper right')

        # Save plot
        plots_dir = Path("DANCEST_model/results")
        plots_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(plots_dir / f"cmapss_{dataset_id}_training_{timestamp}.png")
        if verbose:
            plt.show()
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # training several datasets in one run does not accumulate figures
        plt.close(fig)

    return model, scaler, history, test_mae

def train_all_datasets(datasets, verbose=False, data_dir=None):
    """
    Train neural models for multiple CMAPSS datasets.

    A failure on one dataset is logged (with traceback) and recorded in the
    result; it does not stop training of the remaining datasets.

    Args:
        datasets: List of dataset IDs to train
        verbose: Whether to enable verbose output
        data_dir: Optional path to the directory containing the CMAPSS data
            files; forwarded to train_model

    Returns:
        Dictionary keyed by dataset ID. On success the value holds 'model',
        'scaler', 'history' and 'test_mae'; on failure it holds 'error' with
        the exception message.
    """
    results = {}

    for dataset_id in datasets:
        logger.info(f"=== Starting training for CMAPSS {dataset_id} ===")
        try:
            model, scaler, history, test_mae = train_model(dataset_id, verbose, data_dir)
        except Exception as e:
            # Broad on purpose: one broken dataset must not abort the batch.
            # exc_info keeps the traceback that the message alone would lose.
            logger.error(f"Error training model for {dataset_id}: {e}", exc_info=True)
            results[dataset_id] = {'error': str(e)}
        else:
            results[dataset_id] = {
                'model': model,
                'scaler': scaler,
                'history': history,
                'test_mae': test_mae
            }
            logger.info(f"=== Completed training for CMAPSS {dataset_id} ===")

    return results

def main():
    """
    Command-line entry point: parse arguments, train one model per requested
    CMAPSS dataset and print a summary.

    Dataset IDs are matched case-insensitively; duplicates are trained once
    and unrecognized IDs are reported and skipped.

    Returns:
        Dictionary keyed by dataset ID with either the training artifacts
        ('model', 'scaler', 'history', 'test_mae') or an 'error' message,
        or None when no valid dataset was requested.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Train neural models for CMAPSS datasets')
    parser.add_argument('--datasets', nargs='+', default=['FD001'],
                        help='CMAPSS datasets to train (FD001, FD002, FD003, FD004)')
    parser.add_argument('--verbose', action='store_true', help='Enable verbose output')
    parser.add_argument('--data_dir', type=str, default=None,
                        help='Explicit path to the directory containing CMAPSS data files')
    args = parser.parse_args()

    # Check for valid datasets: normalize case, drop duplicates (keeping the
    # requested order) and remember what had to be ignored
    valid_datasets = ['FD001', 'FD002', 'FD003', 'FD004']
    datasets = []
    ignored = []
    for requested in args.datasets:
        normalized = requested.upper()
        if normalized not in valid_datasets:
            ignored.append(requested)
        elif normalized not in datasets:
            datasets.append(normalized)

    if ignored:
        # Tell the user instead of silently discarding a mistyped ID
        logger.warning(f"Ignoring unrecognized datasets: {', '.join(ignored)}")

    if not datasets:
        logger.error("No valid datasets specified")
        print("Please specify at least one valid dataset: FD001, FD002, FD003, FD004")
        return

    logger.info(f"Training models for datasets: {', '.join(datasets)}")

    # Create necessary directories
    os.makedirs("DANCEST_model/models/saved", exist_ok=True)
    os.makedirs("DANCEST_model/results", exist_ok=True)

    # If data_dir was provided, check it first
    if args.data_dir:
        data_dir = Path(args.data_dir)
        if data_dir.exists():
            logger.info(f"Using specified data directory: {data_dir}")
        else:
            # Not fatal: the standard locations are still searched
            logger.warning(f"Specified data directory {data_dir} does not exist")

    # Print information about where we're looking for data
    print("\nLooking for CMAPSS data files in these locations:")
    if args.data_dir:
        print(f"  - {args.data_dir} (user-specified)")
    print("  - data/@CMAPSS")
    print("  - data/CMAPSS")
    print("  - data")
    print("  - DANCEST_model/data/CMAPSS")
    print("  - And other standard locations\n")

    # Train models; a failure on one dataset does not stop the others
    results = {}
    for dataset_id in datasets:
        logger.info(f"=== Starting training for CMAPSS {dataset_id} ===")
        try:
            model, scaler, history, test_mae = train_model(dataset_id, args.verbose, args.data_dir)
        except Exception as e:
            # exc_info keeps the traceback that the message alone would lose
            logger.error(f"Error training model for {dataset_id}: {e}", exc_info=True)
            results[dataset_id] = {'error': str(e)}
        else:
            results[dataset_id] = {
                'model': model,
                'scaler': scaler,
                'history': history,
                'test_mae': test_mae
            }
            logger.info(f"=== Completed training for CMAPSS {dataset_id} ===")

    # Print summary
    print("\nTraining Summary:")
    for dataset_id, result in results.items():
        if 'error' in result:
            print(f"  {dataset_id}: Failed - {result['error']}")
        else:
            print(f"  {dataset_id}: Success - Test MAE: {result['test_mae']:.2f}")
            print(f"            Model saved to: DANCEST_model/models/saved/cmapss_{dataset_id}_model_*.keras")

    return results

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()