import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import os
import joblib
from datetime import datetime
from pathlib import Path
import itertools
import argparse
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Import the CMAPSS data loader
from cmapss.load_cmapss_data import load_cmapss

# Pin NumPy's global RNG and TensorFlow's global seed to the same value so
# that repeated runs of this script start from the same random state.
_SEED = 42
np.random.seed(_SEED)
tf.random.set_seed(_SEED)

def create_neural_model(input_dim, hidden_dim=256, dropout_rate=0.3, learning_rate=0.0001):
    """
    Build and compile a feed-forward network that regresses RUL from a flat feature vector.

    Architecture: a shared Dense(hidden_dim) + BatchNorm + Dropout encoder feeds
    two parallel Dense + Dropout branches (labelled "degradation" and
    "condition"; they are structurally identical and nothing constrains them to
    learn those factors). The branch outputs are concatenated and passed
    through two more Dense layers to a single linear output unit.

    The loss is plain MSE; no physics-based term is added. The Dropout layers
    are standard Keras Dropout, which is only active during training, so the
    model as returned produces point estimates. Uncertainty estimates would
    require e.g. calling it with ``training=True`` at inference (MC dropout),
    which is not done here.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the encoder layer. The branch and head widths are
            derived from it as hidden_dim // 2, (2 * hidden_dim) // 3 and
            hidden_dim // 3, each clamped to at least 1 unit.
        dropout_rate: Dropout rate applied after the encoder, after each
            branch, and after the first head layer.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        model: Keras model compiled with MSE loss and MAE / RMSE metrics
            (the RMSE metric is named 'rmse', so its validation counterpart
            is logged as 'val_rmse').
    """
    # Derived layer widths as real ints. The previous `hidden_dim // 1.5`
    # yielded a float (e.g. 170.0), which Keras 3's Dense rejects.
    # (2 * hidden_dim) // 3 equals floor(hidden_dim / 1.5) for integer inputs,
    # so layer sizes are unchanged. max(1, ...) keeps tiny hidden_dims from
    # producing zero-unit layers.
    branch_units = max(1, hidden_dim // 2)
    head_units = max(1, (2 * hidden_dim) // 3)
    tail_units = max(1, hidden_dim // 3)

    inputs = tf.keras.layers.Input(shape=(input_dim,))

    # Shared encoder
    x = tf.keras.layers.Dense(hidden_dim, activation='relu')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)

    # Branch labelled "engine degradation" (structurally a plain Dense branch)
    degradation_path = tf.keras.layers.Dense(branch_units, activation='relu')(x)
    degradation_path = tf.keras.layers.Dropout(dropout_rate)(degradation_path)

    # Branch labelled "operational condition" (same structure as above)
    condition_path = tf.keras.layers.Dense(branch_units, activation='relu')(x)
    condition_path = tf.keras.layers.Dropout(dropout_rate)(condition_path)

    # Merge both branches
    combined = tf.keras.layers.Concatenate()([degradation_path, condition_path])

    # Regression head
    x = tf.keras.layers.Dense(head_units, activation='relu')(combined)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
    x = tf.keras.layers.Dense(tail_units, activation='relu')(x)

    # Single linear output: the RUL prediction
    outputs = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='mse',
        metrics=['mae', tf.keras.metrics.RootMeanSquaredError(name='rmse')]
    )

    return model

def hyperparameter_tuning(X, y, val_split=0.1, learning_rates=None, batch_sizes=None,
                          hidden_dims=None, epochs=100, patience=12):
    """
    Grid-search learning rate, batch size and hidden width for the neural model.

    The data are split once into train/validation subsets, and a single
    StandardScaler is fitted on the training subset. Every grid combination is
    trained on the same scaled data, and the combination with the lowest
    validation RMSE wins. Per-combination checkpoints, the best model, the
    scaler and a CSV of all results are written under ./models/tuning.

    Args:
        X: Input features.
        y: Target values.
        val_split: Fraction of rows held out for validation.
        learning_rates: Learning rates to try (default [1e-4, 5e-4, 1e-3]).
        batch_sizes: Batch sizes to try (default [32, 64, 128]).
        hidden_dims: Hidden widths to try (default [128, 256, 512]).
        epochs: Maximum number of training epochs per combination.
        patience: EarlyStopping patience, in epochs without val_rmse improvement.

    Returns:
        best_params: Dict with keys 'lr', 'batch_size', 'hidden_dim'.
        best_model: The trained Keras model of the best combination.
        scaler: The StandardScaler fitted on the training split. It is shared
            by all combinations, including the best one.

    Raises:
        ValueError: If any of the hyperparameter lists is empty.
        RuntimeError: If no combination produced a validation RMSE below
            infinity (e.g. every run diverged to NaN).
    """
    print("Starting hyperparameter tuning...")

    # Resolve the search space; the defaults reproduce the original fixed grid.
    learning_rates = [1e-4, 5e-4, 1e-3] if learning_rates is None else list(learning_rates)
    batch_sizes = [32, 64, 128] if batch_sizes is None else list(batch_sizes)
    hidden_dims = [128, 256, 512] if hidden_dims is None else list(hidden_dims)

    total_combos = len(learning_rates) * len(batch_sizes) * len(hidden_dims)
    if total_combos == 0:
        raise ValueError(
            "Hyperparameter grid is empty: learning_rates, batch_sizes and "
            "hidden_dims must each contain at least one value."
        )

    # Split data into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=val_split, random_state=42
    )

    # Standardize features (fit on the training split only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    models_dir = Path('./models/tuning')
    models_dir.mkdir(parents=True, exist_ok=True)

    results = []
    best_val_rmse = float('inf')
    best_params = None
    best_model = None

    print(f"Testing {total_combos} hyperparameter combinations")

    grid = itertools.product(learning_rates, batch_sizes, hidden_dims)
    for i, (lr, bs, hd) in enumerate(grid, start=1):
        print(f"\nTesting combination {i}/{total_combos}:")
        print(f"Learning rate: {lr}, Batch size: {bs}, Hidden dim: {hd}")

        model = create_neural_model(
            input_dim=X_train_scaled.shape[1],
            hidden_dim=hd,
            learning_rate=lr
        )

        # Fresh callbacks per run, so no early-stopping state can carry over
        # from one combination to the next.
        early_stopping = EarlyStopping(
            monitor='val_rmse',
            patience=patience,
            restore_best_weights=True,
            mode='min'
        )
        checkpoint_path = models_dir / f"model_lr{lr}_bs{bs}_hd{hd}.keras"
        checkpoint = ModelCheckpoint(
            str(checkpoint_path),
            monitor='val_rmse',
            save_best_only=True,
            mode='min'
        )

        history = model.fit(
            X_train_scaled, y_train,
            validation_data=(X_val_scaled, y_val),
            epochs=epochs,
            batch_size=bs,
            callbacks=[early_stopping, checkpoint],
            verbose=1
        )

        # evaluate() returns [loss, *metrics] in compile order: mse, mae, rmse.
        val_loss, val_mae, val_rmse = model.evaluate(X_val_scaled, y_val, verbose=0)

        results.append({
            'lr': lr,
            'batch_size': bs,
            'hidden_dim': hd,
            'val_loss': val_loss,
            'val_mae': val_mae,
            'val_rmse': val_rmse,
            'epochs_trained': len(history.history['loss'])
        })

        # A NaN RMSE (diverged run) compares False, so it can never be selected.
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            best_params = {'lr': lr, 'batch_size': bs, 'hidden_dim': hd}
            best_model = model
            model.save(str(models_dir / "best_model.keras"))
            joblib.dump(scaler, models_dir / "best_scaler.joblib")

    # Persist the full results table for later analysis
    results_df = pd.DataFrame(results)
    results_df.to_csv(models_dir / "tuning_results.csv", index=False)

    if best_params is None:
        # Fail here with a clear message instead of letting callers index None.
        raise RuntimeError(
            "Hyperparameter tuning found no combination with a finite validation "
            f"RMSE; see {models_dir / 'tuning_results.csv'} for per-run results."
        )

    print("\nHyperparameter tuning completed.")
    print(f"Best parameters: {best_params}")
    print(f"Best validation RMSE: {best_val_rmse:.4f}")

    return best_params, best_model, scaler

def train_final_model(X, y, best_params, test_size=0.2, val_size=0.1, epochs=100, patience=12):
    """
    Train the final model with the given hyperparameters and report test metrics.

    Rows are first split into train/test with ``random_state=42``. A
    validation subset is then carved out of the *training* portion and used
    for early stopping and checkpoint selection. The test set stays untouched
    until the final evaluation, so the reported test metrics are not biased
    by model selection.

    Args:
        X: Input features.
        y: Target values.
        best_params: Dict with keys 'lr', 'batch_size', 'hidden_dim'
            (as returned by hyperparameter tuning).
        test_size: Fraction of rows held out as the test set.
        val_size: Fraction of the training portion held out for validation.
            If 0 or None, the test set is used for validation instead. This
            is the legacy behaviour, and it biases the reported test metrics.
        epochs: Maximum number of training epochs.
        patience: EarlyStopping patience, in epochs without val_rmse improvement.

    Returns:
        model: Trained model.
        scaler: StandardScaler fitted on the training rows only.
        evaluation: List [test_mse, test_mae, test_rmse] from model.evaluate.
    """
    print("\nTraining final model with best hyperparameters...")

    # Hold out the test set first; the split is identical to the legacy one.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )

    # Validation data for early stopping must not overlap the test set.
    if val_size:
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, test_size=val_size, random_state=42
        )
    else:
        X_val, y_val = X_test, y_test

    # Standardize features using statistics from the training rows only
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)
    X_test_scaled = scaler.transform(X_test)

    model = create_neural_model(
        input_dim=X_train_scaled.shape[1],
        hidden_dim=best_params['hidden_dim'],
        learning_rate=best_params['lr']
    )

    early_stopping = EarlyStopping(
        monitor='val_rmse',
        patience=patience,
        restore_best_weights=True,
        mode='min'
    )

    models_dir = Path('./models/saved')
    models_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    checkpoint = ModelCheckpoint(
        str(models_dir / f"cmapss_model_{timestamp}.keras"),
        monitor='val_rmse',
        save_best_only=True,
        mode='min'
    )

    model.fit(
        X_train_scaled, y_train,
        validation_data=(X_val_scaled, y_val),
        epochs=epochs,
        batch_size=best_params['batch_size'],
        callbacks=[early_stopping, checkpoint],
        verbose=1
    )

    # Final, held-out evaluation: [loss (MSE), MAE, RMSE] in compile order
    evaluation = model.evaluate(X_test_scaled, y_test)
    print(f"Test loss (MSE): {evaluation[0]:.4f}")
    print(f"Test MAE: {evaluation[1]:.4f}")
    print(f"Test RMSE: {evaluation[2]:.4f}")

    return model, scaler, evaluation

def save_model(model, scaler, feature_names=None, dataset_name="FD001"):
    """
    Save the trained model, its scaler and (optionally) feature names to ./models/saved.

    All files share one timestamp suffix (``%Y%m%d_%H%M%S``) so that a model
    can be matched with its scaler and feature list.

    Args:
        model: Trained Keras model; written as ``cmapss_<dataset>_model_<ts>.keras``.
        scaler: Fitted StandardScaler; written with joblib as
            ``cmapss_<dataset>_scaler_<ts>.joblib``.
        feature_names: Optional iterable of feature column names, written one
            per line to ``cmapss_<dataset>_features_<ts>.txt``. May be a list,
            tuple, numpy array or pandas Index. Skipped when None or empty.
        dataset_name: CMAPSS dataset name (e.g. FD001), used in file names.
    """
    models_dir = Path('./models/saved')
    models_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    model_path = models_dir / f'cmapss_{dataset_name}_model_{timestamp}.keras'
    model.save(str(model_path))
    print(f"Model saved to: {model_path}")

    scaler_path = models_dir / f'cmapss_{dataset_name}_scaler_{timestamp}.joblib'
    joblib.dump(scaler, scaler_path)
    print(f"Scaler saved to: {scaler_path}")

    # Explicit None/length check: plain truthiness (`if feature_names:`)
    # raises ValueError for numpy arrays and pandas Index objects.
    if feature_names is not None and len(feature_names) > 0:
        feature_path = models_dir / f'cmapss_{dataset_name}_features_{timestamp}.txt'
        feature_path.write_text(
            "".join(f"{name}\n" for name in feature_names),
            encoding="utf-8",
        )
        print(f"Feature names saved to: {feature_path}")

def main(argv=None):
    """
    Train the DANCEST neural model on one CMAPSS subset: tune, retrain, save.

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            reads ``sys.argv[1:]`` as usual; passing a list allows the
            pipeline to be driven programmatically.
    """
    parser = argparse.ArgumentParser(description='Train DANCEST neural model with CMAPSS dataset')
    parser.add_argument('--dataset', type=str, default='FD001', choices=['FD001', 'FD002', 'FD003', 'FD004'],
                        help='CMAPSS dataset to use (default: FD001)')
    args = parser.parse_args(argv)

    print(f"Loading CMAPSS {args.dataset} dataset...")
    X, y, feature_names = load_cmapss(args.dataset)

    print(f"Loaded dataset with {X.shape[0]} samples and {X.shape[1]} features")
    print(f"Features: {feature_names}")

    # Only the selected hyperparameters are reused below. The tuned model and
    # scaler are already written to ./models/tuning by hyperparameter_tuning.
    best_params, _, _ = hyperparameter_tuning(X, y, val_split=0.1)

    # NOTE(review): tuning and final training draw independent random row
    # splits of the same data, so rows in the final test set were likely used
    # for training during tuning. The selected hyperparameters may therefore be
    # mildly optimistic. If X holds several rows per engine unit (typical for
    # CMAPSS; confirm in load_cmapss), row-level splits also leak
    # between subsets, and a unit-grouped split would be more faithful.
    final_model, final_scaler, evaluation = train_final_model(X, y, best_params)

    save_model(final_model, final_scaler, feature_names, args.dataset)

    print(f"DANCEST neural model training complete for CMAPSS {args.dataset}!")

if __name__ == "__main__":
    # Run the full tune -> retrain -> save pipeline only when executed as a
    # script; importing this module only defines the functions (plus seeding).
    main()