"""
Inference module for TurtleBot4 Navigation Surrogate Model.

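This module provides a simple interface for loading and running predictions
with the trained surrogate model. It loads a single pickle file,
'turtlebot4_nav_model.pkl', located in the same directory as this module; the
file holds the serialized ONNX model for each target together with the
feature and target column lists.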
"""

import os
import pickle
import numpy as np
import onnxruntime as ort
from typing import Dict


# Most recently loaded model bundle, keyed by (pickle path, file mtime in ns).
# Unpickling and building ONNX InferenceSessions costs far more than a forward
# pass, so predict() reuses them across calls. Keying on the modification time
# means a re-saved model file is picked up automatically on the next call.
_MODEL_CACHE: Dict[tuple, tuple] = {}

# Fidelity levels accepted by predict(), and the absolute tolerance used when
# matching an input fidelity against them (meant to absorb float noise only,
# e.g. float32 round-trips, not to snap nearby values onto the grid).
_VALID_FIDELITIES = (0.2, 0.5, 1.0)
_FIDELITY_ATOL = 1e-6


def _load_models(pickle_path: str) -> tuple:
    """Return ``(models, feature_cols, target_cols)`` loaded from *pickle_path*.

    ``models`` maps each target name to an ``ort.InferenceSession`` built from
    the serialized ONNX bytes stored under ``'onnx_bytes'`` in the pickle;
    ``feature_cols`` and ``target_cols`` are read from the keys of the same
    names. The result is cached per (path, modification time), and only the
    latest bundle is retained.

    Raises:
        FileNotFoundError: If *pickle_path* does not exist.
    """
    try:
        mtime_ns = os.stat(pickle_path).st_mtime_ns
    except FileNotFoundError:
        raise FileNotFoundError(
            f"Model file not found: {pickle_path}\n"
            "Please train and save models first using model.py"
        ) from None

    key = (pickle_path, mtime_ns)
    cached = _MODEL_CACHE.get(key)
    if cached is not None:
        return cached

    # SECURITY: pickle.load can execute arbitrary code. Only load a trusted
    # model file shipped alongside this module, never user-supplied data.
    with open(pickle_path, 'rb') as f:
        saved_dict = pickle.load(f)

    # Reconstruct ONNX InferenceSession objects from the stored bytes.
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 3  # Suppress warnings
    models = {
        target: ort.InferenceSession(onnx_bytes, sess_options=sess_options)
        for target, onnx_bytes in saved_dict['onnx_bytes'].items()
    }

    bundle = (models, saved_dict['feature_cols'], saved_dict['target_cols'])
    # Keep only the newest bundle so sessions for a replaced model file are
    # released instead of living for the rest of the process.
    _MODEL_CACHE.clear()
    _MODEL_CACHE[key] = bundle
    return bundle


def predict(X: np.ndarray) -> Dict[str, np.ndarray]:
    """
    Make predictions using the trained TurtleBot4 navigation surrogate model.

    Loads the pre-trained models from 'turtlebot4_nav_model.pkl' (located next
    to this module) and runs them on the input features. The unpickled models
    are cached, so only the first call, or the first call after the file is
    re-saved, pays the loading cost. Batched inputs of any batch size are
    supported as long as the feature dimension matches the model's expected
    input size (26 features).

    Args:
        X: Input features as numpy array with shape (batch_size, 26) or (26,).
           Features must be in the following order:
           [min_vel_x, max_vel_x, max_vel_theta, max_speed_xy, acc_lim_x,
            acc_lim_theta, decel_lim_x, decel_lim_theta, vx_samples, vtheta_samples,
            sim_time, BaseObstacle.scale, PathAlign.scale, GoalAlign.scale,
            PathDist.scale, GoalDist.scale, RotateToGoal.scale, local_width,
            local_height, local_resolution, local_inflation_radius,
            local_cost_scaling_factor, global_resolution, global_inflation_radius,
            global_cost_scaling_factor, fidelity]

           Note: fidelity must be one of [0.2, 0.5, 1.0] (matched up to an
           absolute tolerance of 1e-6).

    Returns:
        Dictionary mapping each target variable to its predicted values:
            - 'task_completion_rate': Probability of task completion [0, 1]
            - 'task_execution_time': Time to complete task in seconds
            - 'total_energy': Total energy consumed in Wh
            - 'distance_traveled': Total distance traveled in meters
            - 'collision_risk_score': Risk score for collisions [0, 1]
            - 'energy_per_meter': Energy efficiency in Wh/meter, computed as
              total_energy / distance_traveled. A predicted distance of 0
              yields a non-finite value (inf/nan) for that sample.

        All values are 1D numpy arrays with shape (batch_size,). A 1D input is
        treated as a batch of one, so each array then has shape (1,).

    Raises:
        FileNotFoundError: If 'turtlebot4_nav_model.pkl' is not found in the same directory.
        ValueError: If input shape is invalid or fidelity values are not in [0.2, 0.5, 1.0].

    Example:
        >>> import numpy as np
        >>> # Single sample
        >>> x_single = np.array([...])  # 26 features
        >>> predictions = predict(x_single)
        >>> print(predictions['energy_per_meter'])  # shape (1,)

        >>> # Batch of samples
        >>> x_batch = np.random.randn(10, 26)  # 10 samples
        >>> x_batch[:, -1] = 1.0  # Set fidelity to 1.0
        >>> predictions = predict(x_batch)
        >>> print(predictions['energy_per_meter'].shape)  # (10,)
    """
    # The model file lives next to this module, independent of the CWD.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    pickle_path = os.path.join(script_dir, 'turtlebot4_nav_model.pkl')
    models, feature_cols, target_cols = _load_models(pickle_path)

    # Validate input shape; promote a single sample to a batch of one.
    X = np.asarray(X)
    if X.ndim == 1:
        X = X.reshape(1, -1)
    elif X.ndim != 2:
        raise ValueError(f"Input must be 1D or 2D array, got shape {X.shape}")

    expected_features = len(feature_cols)
    if X.shape[1] != expected_features:
        raise ValueError(
            f"Input feature dimension mismatch. Expected {expected_features} features, "
            f"got {X.shape[1]}. Features must be: {feature_cols}"
        )

    # Validate fidelity values; the column is located by name, not position.
    if 'fidelity' in feature_cols:
        fidelity_idx = feature_cols.index('fidelity')
        fidelity_values = X[:, fidelity_idx].astype(np.float64)
        valid_fidelities = np.array(_VALID_FIDELITIES)

        # Require a match to one level up to float noise. Rounding to one
        # decimal would accept off-grid values such as 0.46 while still
        # passing them unchanged to the model. NaN never matches.
        is_valid = np.isclose(
            fidelity_values[:, np.newaxis],
            valid_fidelities,
            rtol=0.0,
            atol=_FIDELITY_ATOL,
        ).any(axis=1)
        if not np.all(is_valid):
            invalid_values = np.unique(fidelity_values[~is_valid])
            raise ValueError(
                f"Invalid fidelity value(s) found: {invalid_values}. "
                f"Fidelity must be one of: {valid_fidelities.tolist()}"
            )

    # The ONNX models take float32 input.
    X_array = X.astype(np.float32)

    result = {}
    for target in target_cols:
        session = models[target]
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        pred = session.run([output_name], {input_name: X_array})[0]
        # Flatten so every target is 1D regardless of the model's output shape.
        result[target] = np.asarray(pred).ravel()

    # Derived metric: energy per unit distance.
    result['energy_per_meter'] = result['total_energy'] / result['distance_traveled']

    return result


if __name__ == "__main__":
    # Example usage demonstrating single and batch predictions.
    print("="*80)
    print("INFERENCE EXAMPLE")
    print("="*80)

    # Example: Create a sample input (26 features)
    sample = np.array([
        -0.2,   # min_vel_x
        0.4,    # max_vel_x
        1.2,    # max_vel_theta
        0.45,   # max_speed_xy
        1.0,    # acc_lim_x
        2.5,    # acc_lim_theta
        -3.0,   # decel_lim_x
        -2.5,   # decel_lim_theta
        20,     # vx_samples
        20,     # vtheta_samples
        2.5,    # sim_time
        0.05,   # BaseObstacle.scale
        40.0,   # PathAlign.scale
        20.0,   # GoalAlign.scale
        30.0,   # PathDist.scale
        20.0,   # GoalDist.scale
        40.0,   # RotateToGoal.scale
        3.5,    # local_width
        3.5,    # local_height
        0.05,   # local_resolution
        0.4,    # local_inflation_radius
        8.0,    # local_cost_scaling_factor
        0.05,   # global_resolution
        0.4,    # global_inflation_radius
        5.0,    # global_cost_scaling_factor
        1.0,    # fidelity (must be 0.2, 0.5, or 1.0)
    ])

    print("\nSingle sample prediction:")
    print(f"Input shape: {sample.shape}")

    try:
        predictions = predict(sample)
        print("\nPredictions:")
        # A 1D input is a batch of one, so each value is a length-1 array.
        for key, value in predictions.items():
            print(f"  {key}: {value[0]:.4f}")
    except FileNotFoundError as e:
        print(f"\nError: {e}")
        print("\nTo use this module, first train and save a model:")
        print("  python model.py")
        print("  # Then in the script, add:")
        print("  save_model(cb, 'rescue_bench/surrogates/turtlebot4_navigation/turtlebot4_nav_model.pkl')")

    print("\n" + "="*80)
    print("Batch prediction example:")
    print("="*80)

    # Create a batch of 5 samples from the single sample above
    batch_size = 5
    batch = np.tile(sample, (batch_size, 1))
    # Vary some parameters
    batch[:, 1] = np.linspace(0.3, 0.5, batch_size)  # max_vel_x
    batch[:, -1] = np.random.choice([0.2, 0.5, 1.0], batch_size)  # fidelity

    print(f"\nInput shape: {batch.shape}")

    try:
        predictions = predict(batch)
        print("\nPredictions for batch:")
        print(predictions)
    except FileNotFoundError:
        print("\nSkipping batch example (model not found)")

    print("\n" + "="*80)
