import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
import os
import joblib
from datetime import datetime
from pathlib import Path
import itertools
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Seed TensorFlow's and NumPy's global random generators once, at import
# time, so repeated runs of this script are reproducible.
tf.random.set_seed(42)
np.random.seed(42)

def _first_existing_path(candidates):
    """Return the first entry of *candidates* that exists on disk, or None if none do."""
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def load_ANONYMIZED_corrosion_data(dataset_path='[ANONYMIZED]_lp_dataset', sample_size=None):
    """
    Load the corrosion dataset from the [ANONYMIZED] LP dataset.

    The corrosion CSV is looked up in a list of candidate locations (relative
    to the working directory and to this file); the first one that exists is
    used. If a materials CSV is found under ``dataset_path`` it is left-joined
    on ``blade_id`` and its columns become features. ``time_point`` is used as
    a feature, and known categorical columns are one-hot encoded.

    Args:
        dataset_path: Path to the dataset directory
        sample_size: Optional number of leading rows to read for quicker
            testing (None or 0 reads the full file)

    Returns:
        features: DataFrame of input features, columns in ``feature_names`` order
        targets: Target values as a NumPy array of shape (n_rows, 1)
        feature_names: Names of the feature columns

    Raises:
        FileNotFoundError: If none of the candidate corrosion files exists.
        ValueError: If no known target column is present.
    """
    print("Loading [ANONYMIZED] LP corrosion dataset...")

    corrosion_name = '[ANONYMIZED]_lp_corrosion.csv'
    script_dir = os.path.dirname(__file__)
    # Candidate locations, tried in order.
    # NOTE(review): at each directory level the adapted *test* CSV is tried
    # before the adapted *train* CSV, so a training run may load test data
    # -- confirm this ordering is intended.
    possible_paths = [
        os.path.join(dataset_path, corrosion_name),
        os.path.join('..', dataset_path, corrosion_name),
        os.path.join('..', '..', dataset_path, corrosion_name),
        os.path.join('..', '..', '..', dataset_path, corrosion_name),
        os.path.join('..', '..', 'adapted_test.csv'),
        os.path.join('..', 'adapted_test.csv'),
        os.path.join('..', '..', 'adapted_train.csv'),
        os.path.join('..', 'adapted_train.csv'),
        os.path.abspath(os.path.join(script_dir, '..', '..', 'adapted_test.csv')),
        os.path.abspath(os.path.join(script_dir, '..', '..', 'adapted_train.csv')),
        os.path.abspath(os.path.join(script_dir, '..', '..', dataset_path, corrosion_name)),
        'adapted_test.csv',
        'adapted_train.csv',
    ]

    corrosion_file = _first_existing_path(possible_paths)
    if not corrosion_file:
        raise FileNotFoundError(f"Could not find corrosion dataset at any of these locations: {possible_paths}")
    print(f"Found dataset at: {corrosion_file}")

    # A falsy sample_size (None/0) means "read the whole file"; nrows=None does that.
    corrosion_df = pd.read_csv(corrosion_file, nrows=sample_size or None)

    print(f"Loaded {len(corrosion_df)} corrosion data points")
    print(f"Columns: {corrosion_df.columns.tolist()}")

    # Extract target column (corrosion depth), falling back to known alternatives.
    target_col = 'corrosion_depth_mm'
    if target_col not in corrosion_df.columns:
        for col in ['corrosion_depth', 'depth_mm', 'RUL']:
            if col in corrosion_df.columns:
                print(f"Using '{col}' as target column")
                target_col = col
                break
        else:
            raise ValueError(f"Could not find target column. Available columns: {corrosion_df.columns.tolist()}")

    # Identify feature columns (exclude ID/coordinate columns and the target).
    exclude_cols = ['blade_id', 'x_coord', 'y_coord', 'time_point', target_col]
    feature_cols = [col for col in corrosion_df.columns if col not in exclude_cols]

    # Optionally add material properties (best effort: failures only warn).
    materials_name = '[ANONYMIZED]_lp_materials.csv'
    material_paths = [
        os.path.join(dataset_path, materials_name),
        os.path.join('..', dataset_path, materials_name),
    ]
    try:
        materials_file = _first_existing_path(material_paths)
        if materials_file:
            print(f"Loading material properties from {materials_file}")
            materials_df = pd.read_csv(materials_file)
            # A column present in both frames would be renamed *_x/*_y by the
            # merge, leaving feature names that no longer exist. Keep the
            # corrosion-data copy and drop the duplicate from the materials.
            overlap = [col for col in materials_df.columns
                       if col != 'blade_id' and col in corrosion_df.columns]
            if overlap:
                print(f"Warning: ignoring material columns already present in corrosion data: {overlap}")
                materials_df = materials_df.drop(columns=overlap)
            # Left join keeps every corrosion row; NOTE(review): duplicate
            # blade_id rows in the materials file would multiply rows here.
            corrosion_df = pd.merge(corrosion_df, materials_df, on='blade_id', how='left')
            feature_cols.extend(col for col in materials_df.columns if col != 'blade_id')
    except Exception as e:
        print(f"Warning: Could not load material data: {e}")

    # time_point was excluded above only so it lands after the other features.
    if 'time_point' in corrosion_df.columns:
        feature_cols.append('time_point')

    # One-hot encode categorical features, replacing the raw column in the feature list.
    categorical_cols = ['alloy_type', 'heat_treatment', 'surface_coating', 'manufacturing_batch']
    for col in categorical_cols:
        if col not in corrosion_df.columns:
            continue
        dummies = pd.get_dummies(corrosion_df[col], prefix=col)
        corrosion_df = pd.concat([corrosion_df, dummies], axis=1)
        feature_cols.extend(dummies.columns)
        if col in feature_cols:
            feature_cols.remove(col)

    # Extract features and target (target reshaped to a column for Keras).
    X = corrosion_df[feature_cols].copy()
    y = corrosion_df[target_col].values.reshape(-1, 1)

    print(f"Prepared data shapes - X: {X.shape}, y: {y.shape}")
    print(f"Feature columns: {feature_cols}")

    return X, y, feature_cols

def _time_feature_index(feature_names):
    """Return the index of the first feature whose name contains 'time' (case-insensitive), or None."""
    if feature_names is None:
        return None
    for index, name in enumerate(feature_names):
        if 'time' in str(name).lower():
            return index
    return None


def create_neural_model(input_dim, hidden_dim=256, dropout_rate=0.3, learning_rate=0.0001,
                        feature_names=None):
    """
    Create a two-branch dense network for corrosion prediction.

    A shared dense encoder feeds two parallel branches ("material" and
    "time"). Their outputs are concatenated, optionally with a sqrt(time)
    feature, and passed through two dense layers to a single linear output.

    Args:
        input_dim: Number of input features
        hidden_dim: Width of the first hidden layer; later layers are
            hidden_dim//2, hidden_dim//1.5 and hidden_dim//3 (each at least 1)
        dropout_rate: Dropout rate applied after the hidden layers. This
            function only adds the layers; sampling uncertainty would require
            running the model with dropout active at inference time.
        learning_rate: Learning rate for the Adam optimizer
        feature_names: Optional sequence of ``input_dim`` column names. When
            given and one contains "time", sqrt of that input column is
            concatenated before the final layers. When omitted (the default),
            no time branch is added, which matches the previous behaviour.

    Returns:
        model: Keras model compiled with MSE loss and MAE/RMSE metrics

    Raises:
        ValueError: If ``feature_names`` is given with a length other than ``input_dim``.
    """
    if feature_names is not None and len(feature_names) != input_dim:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but input_dim is {input_dim}"
        )

    # Integer widths: Dense rejects float unit counts in newer Keras, and
    # hidden_dim//1.5 is a float. Floor at 1 so tiny hidden_dim stays valid.
    half_dim = max(1, hidden_dim // 2)
    fused_dim = max(1, int(hidden_dim // 1.5))
    third_dim = max(1, hidden_dim // 3)

    inputs = tf.keras.layers.Input(shape=(input_dim,))

    # Shared encoder over the full feature vector.
    x = tf.keras.layers.Dense(hidden_dim, activation='relu')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)

    # Material degradation path
    mat_path = tf.keras.layers.Dense(half_dim, activation='relu')(x)
    mat_path = tf.keras.layers.Dropout(dropout_rate)(mat_path)

    # Time evolution path
    time_path = tf.keras.layers.Dense(half_dim, activation='relu')(x)
    time_path = tf.keras.layers.Dropout(dropout_rate)(time_path)

    time_idx = _time_feature_index(feature_names)
    if time_idx is not None:
        time_feature = tf.keras.layers.Lambda(lambda t: t[:, time_idx:time_idx + 1])(inputs)
        # Inputs are typically standardised, so time can be negative; clamp
        # at 0 before the square root to avoid NaNs.
        sqrt_time = tf.keras.layers.Lambda(
            lambda t: tf.sqrt(tf.maximum(t, 0.0) + 1e-6)
        )(time_feature)
        combined = tf.keras.layers.Concatenate()([mat_path, time_path, sqrt_time])
    else:
        combined = tf.keras.layers.Concatenate()([mat_path, time_path])

    # Final dense layers
    x = tf.keras.layers.Dense(fused_dim, activation='relu')(combined)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
    x = tf.keras.layers.Dense(third_dim, activation='relu')(x)

    # Output corrosion prediction
    outputs = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='mse',
        metrics=['mae', tf.keras.metrics.RootMeanSquaredError(name='rmse')]
    )

    return model

def hyperparameter_tuning(X, y, val_split=0.1, learning_rates=None, batch_sizes=None,
                          hidden_dims=None, epochs=100):
    """
    Grid-search learning rate, batch size and hidden width for the neural model.

    Every combination is trained on the same standardised train split and
    scored by RMSE on a held-out validation split. Per-combination
    checkpoints, the best model, the scaler and a CSV of all results are
    written under ./models/tuning.

    Args:
        X: Input features
        y: Target values
        val_split: Fraction of the data held out for validation
        learning_rates: Learning rates to try (default [1e-4, 5e-4, 1e-3])
        batch_sizes: Batch sizes to try (default [32, 64, 128])
        hidden_dims: Hidden widths to try (default [128, 256, 512])
        epochs: Maximum epochs per combination (early stopping may end sooner)

    Returns:
        best_params: Dict with 'lr', 'batch_size', 'hidden_dim' of the best run
        best_model: Best model
        best_scaler: Scaler used with best model (shared by all runs)

    Raises:
        RuntimeError: If no combination produced a finite validation RMSE
            (e.g. an empty grid or NaN losses), so there is no best model.
    """
    print("Starting hyperparameter tuning...")

    learning_rates = [1e-4, 5e-4, 1e-3] if learning_rates is None else list(learning_rates)
    batch_sizes = [32, 64, 128] if batch_sizes is None else list(batch_sizes)
    hidden_dims = [128, 256, 512] if hidden_dims is None else list(hidden_dims)

    # Split data into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=val_split, random_state=42
    )

    # Standardize features (fit on the training split only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    models_dir = Path('./models/tuning')
    models_dir.mkdir(parents=True, exist_ok=True)

    results = []
    best_val_rmse = float('inf')
    best_params = None
    best_model = None

    grid = list(itertools.product(learning_rates, batch_sizes, hidden_dims))
    total_combos = len(grid)
    print(f"Testing {total_combos} hyperparameter combinations")

    for i, (lr, bs, hd) in enumerate(grid, start=1):
        print(f"\nTesting combination {i}/{total_combos}:")
        print(f"Learning rate: {lr}, Batch size: {bs}, Hidden dim: {hd}")

        model = create_neural_model(
            input_dim=X_train_scaled.shape[1],
            hidden_dim=hd,
            learning_rate=lr
        )

        # Fresh callbacks per run so no callback state is shared between combinations.
        early_stopping = EarlyStopping(
            monitor='val_rmse',
            patience=12,
            restore_best_weights=True,
            mode='min'
        )
        checkpoint_path = models_dir / f"model_lr{lr}_bs{bs}_hd{hd}.keras"
        checkpoint = ModelCheckpoint(
            str(checkpoint_path),
            monitor='val_rmse',
            save_best_only=True,
            mode='min'
        )

        history = model.fit(
            X_train_scaled, y_train,
            validation_data=(X_val_scaled, y_val),
            epochs=epochs,
            batch_size=bs,
            callbacks=[early_stopping, checkpoint],
            verbose=1
        )

        # evaluate() returns [loss, *metrics] in compile order: mse, mae, rmse.
        val_loss, val_mae, val_rmse = model.evaluate(X_val_scaled, y_val, verbose=0)

        results.append({
            'lr': lr,
            'batch_size': bs,
            'hidden_dim': hd,
            'val_loss': val_loss,
            'val_mae': val_mae,
            'val_rmse': val_rmse,
            'epochs_trained': len(history.history['loss'])
        })

        # NaN RMSE compares False here, so diverged runs are never selected.
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            best_params = {'lr': lr, 'batch_size': bs, 'hidden_dim': hd}
            best_model = model
            model.save(str(models_dir / "best_model.keras"))
            joblib.dump(scaler, models_dir / "best_scaler.joblib")

    # Persist all results for later analysis, even if nothing was selected.
    results_df = pd.DataFrame(results)
    results_df.to_csv(models_dir / "tuning_results.csv", index=False)

    if best_params is None:
        raise RuntimeError(
            "Hyperparameter tuning did not produce a finite validation RMSE "
            "(empty grid or NaN losses); no best model to return."
        )

    print("\nHyperparameter tuning completed.")
    print(f"Best parameters: {best_params}")
    print(f"Best validation RMSE: {best_val_rmse:.4f}")

    return best_params, best_model, scaler

def train_final_model(X, y, best_params, test_size=0.2, val_size=0.1, epochs=100):
    """
    Train the final model using the best hyperparameters.

    The data is split into train/test. A further ``val_size`` fraction of the
    training split is held out for early stopping and checkpoint selection,
    so the reported test metrics come from rows the training loop never used
    for model selection.

    Args:
        X: Input features
        y: Target values
        best_params: Dict with 'lr', 'batch_size' and 'hidden_dim'
            (as returned by hyperparameter_tuning)
        test_size: Fraction of the data held out as the test set
        val_size: Fraction of the training split held out for validation.
            A falsy value (0/None) falls back to monitoring the test set
            itself (the previous behaviour), which biases the test metrics.
        epochs: Maximum number of training epochs

    Returns:
        model: Trained model
        scaler: StandardScaler fitted on the rows the model was trained on
        evaluation: [mse, mae, rmse] on the test set
    """
    print("\nTraining final model with best hyperparameters...")

    # Split data into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )

    # Carve the validation set out of the training data so the test set stays untouched.
    if val_size:
        X_fit, X_val, y_fit, y_val = train_test_split(
            X_train, y_train, test_size=val_size, random_state=42
        )
    else:
        X_fit, X_val, y_fit, y_val = X_train, X_test, y_train, y_test

    # Standardize features using statistics of the fitting rows only
    scaler = StandardScaler()
    X_fit_scaled = scaler.fit_transform(X_fit)
    X_val_scaled = scaler.transform(X_val)
    X_test_scaled = scaler.transform(X_test)

    model = create_neural_model(
        input_dim=X_fit_scaled.shape[1],
        hidden_dim=best_params['hidden_dim'],
        learning_rate=best_params['lr']
    )

    early_stopping = EarlyStopping(
        monitor='val_rmse',
        patience=12,
        restore_best_weights=True,
        mode='min'
    )

    models_dir = Path('./models/saved')
    models_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    checkpoint = ModelCheckpoint(
        str(models_dir / f"model_{timestamp}.keras"),
        monitor='val_rmse',
        save_best_only=True,
        mode='min'
    )

    model.fit(
        X_fit_scaled, y_fit,
        validation_data=(X_val_scaled, y_val),
        epochs=epochs,
        batch_size=best_params['batch_size'],
        callbacks=[early_stopping, checkpoint],
        verbose=1
    )

    # evaluate() returns [loss (mse), mae, rmse] in compile order.
    evaluation = model.evaluate(X_test_scaled, y_test)
    print(f"Test loss (MSE): {evaluation[0]:.4f}")
    print(f"Test MAE: {evaluation[1]:.4f}")
    print(f"Test RMSE: {evaluation[2]:.4f}")

    return model, scaler, evaluation

def save_model(model, scaler, feature_names=None):
    """
    Save the trained model, its scaler and (optionally) the feature names.

    All artefacts go to ./models/saved and share one timestamp, so files from
    the same run can be matched up.

    Args:
        model: Trained Keras model
        scaler: Fitted StandardScaler
        feature_names: Names of feature columns (optional), written one per line

    Returns:
        dict mapping 'model', 'scaler' and, if written, 'features' to the Path saved.
    """
    models_dir = Path('./models/saved')
    models_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    saved = {}

    # NOTE(review): train_final_model's checkpoint uses the same
    # model_<timestamp>.keras naming in this directory; a save within the
    # same second would overwrite that file.
    model_path = models_dir / f'model_{timestamp}.keras'
    model.save(str(model_path))
    print(f"Model saved to: {model_path}")
    saved['model'] = model_path

    scaler_path = models_dir / f'scaler_{timestamp}.joblib'
    joblib.dump(scaler, scaler_path)
    print(f"Scaler saved to: {scaler_path}")
    saved['scaler'] = scaler_path

    if feature_names:
        feature_path = models_dir / f'features_{timestamp}.txt'
        # Explicit UTF-8 so non-ASCII column names round-trip regardless of
        # the platform's default encoding.
        with open(feature_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{name}\n" for name in feature_names)
        print(f"Feature names saved to: {feature_path}")
        saved['features'] = feature_path

    return saved

def main():
    """Train the DANCEST neural model: load data, tune, retrain with the winning settings, save."""
    # sample_size=None loads every row; pass an int instead for a quick smoke run.
    X, y, feature_names = load_ANONYMIZED_corrosion_data(sample_size=None)

    # Only the winning hyperparameters are used below; the tuned model and
    # scaler returned here are discarded.
    best_params, _tuned_model, _tuned_scaler = hyperparameter_tuning(X, y, val_split=0.1)

    final_model, final_scaler, _evaluation = train_final_model(X, y, best_params)
    save_model(final_model, final_scaler, feature_names)

    print("DANCEST neural model training complete!")


if __name__ == "__main__":
    main()