from __future__ import annotations

import torch
from torch import Tensor

from botorch.test_functions.base import (
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
)

from rescue.problems.robotics.agv_navigation.surrogate import predict
from rescue.problems.robotics.agv_navigation.config import (
    RecommendedParameters
)

class AGVNavigation(
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
):
    r"""
    The AGV Navigation problem for autonomous ground vehicles (AGVs).

    Inputs (26-D): 25 navigation-stack design variables (controller velocity
    and acceleration limits, trajectory sampling, DWB critic scales, local and
    global costmap settings) followed by a fidelity parameter in [0.2, 1.0].
    All metrics are obtained from a surrogate model (``predict``).

    Objectives (2-D):
        - energy_per_meter (Wh/meter) (to minimize)
        - task_execution_time (to minimize)

    Constraints (slack >= 0 means feasible):
        - safety: collision_risk_score <= 5.5
        - task_completion_rate >= 0.8
    """

    dim = 26
    num_objectives = 2
    num_constraints = 2

    design_var_names = [
        # Controller Velocity Limits (4 params)
        'min_vel_x',
        'max_vel_x',
        'max_vel_theta',
        'max_speed_xy',

        # Controller Acceleration Limits (4 params)
        'acc_lim_x',
        'acc_lim_theta',
        'decel_lim_x',
        'decel_lim_theta',

        # Controller Trajectory Sampling (3 params)
        'vx_samples',
        'vtheta_samples',
        'sim_time',

        # Controller DWB Critics (6 params)
        'BaseObstacle.scale',
        'PathAlign.scale',
        'GoalAlign.scale',
        'PathDist.scale',
        'GoalDist.scale',
        'RotateToGoal.scale',

        # Local Costmap (5 params)
        'local_width',
        'local_height',
        'local_resolution',
        'local_inflation_radius',
        'local_cost_scaling_factor',

        # Global Costmap (3 params)
        'global_resolution',
        'global_inflation_radius',
        'global_cost_scaling_factor',

        # Fidelity Parameter (1 param)
        'fidelity'
    ]
    objective_var_names = [
        'energy_per_meter',
        'task_execution_time'
    ]
    constraint_var_names = [
        'collision_risk_score',
        'task_completion_rate'
    ]
    fidelity_param_name = 'fidelity'

    # Constraint thresholds used to compute slack in _evaluate_slack_true.
    max_collision_risk_score = 5.5
    min_task_completion_rate = 0.8

    # Every input column, including fidelity, is treated as continuous.
    continuous_inds = list(range(dim))
    discrete_inds = []
    categorical_inds = []

    fidelity_bounds = [(0.2, 1.0)]
    # Design-variable bounds from the config with the fidelity bound appended
    # last. Presumably get_bounds_array() returns a list of (lower, upper)
    # pairs for the 25 design variables -- confirm against the config module.
    _bounds = RecommendedParameters.get_bounds_array() + fidelity_bounds
    _ref_point = [0.4, 700.0]
    _max_hv = 256.38205537634195  # Approximated using NSGA-II, seed: 510

    def __init__(
        self,
        noise_std: None | float | list[float] = None,
        negate: bool = False,
        dtype: torch.dtype = torch.double,
    ) -> None:
        r"""
        Args:
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the objectives.
            dtype: The dtype that is used for the bounds of the function.
        """
        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
        # Input of the most recent surrogate query and its predictions. The
        # results are reused only when queried again with an identical input.
        self._cached_X: Tensor | None = None
        self._cached_results: dict[str, Tensor] | None = None

    def _measure_pref(self, X: Tensor) -> dict[str, Tensor]:
        """
        Query the surrogate model for all metrics at X and cache the results.

        Args:
            X: A tensor of shape (batch_size, d) or (d,) where d = 26
                (25 design variables + 1 fidelity parameter).

        Returns:
            A dictionary mapping each metric name returned by ``predict``
            (e.g. 'energy_per_meter', 'collision_risk_score') to a tensor with
            the same dtype and device as X.
        """
        # The surrogate operates on numpy arrays.
        predictions = predict(X.detach().cpu().numpy())

        # as_tensor also accepts numpy scalars / Python floats, and converts
        # dtype and device in one step.
        results = {
            key: torch.as_tensor(value, dtype=X.dtype, device=X.device)
            for key, value in predictions.items()
        }

        # Clone so later in-place edits to the caller's X cannot make the
        # cache appear to match a different input.
        self._cached_X = X.detach().clone()
        self._cached_results = results

        return results

    def _is_cached_input(self, X: Tensor) -> bool:
        """Return True if the cached predictions were computed for exactly X."""
        cached_X = getattr(self, '_cached_X', None)
        if cached_X is None or getattr(self, '_cached_results', None) is None:
            return False
        # Check metadata first: torch.equal needs comparable tensors.
        return (
            cached_X.shape == X.shape
            and cached_X.dtype == X.dtype
            and cached_X.device == X.device
            and torch.equal(cached_X, X.detach())
        )

    def _evaluate_true(self, X: Tensor) -> Tensor:
        """
        Evaluate the true objectives for input tensor X.

        Args:
            X: A tensor of shape (batch_size, d) or (d,) where d = 26
                (25 design variables + 1 fidelity parameter).

        Returns:
            A tensor of shape (batch_size, 2) or (2,) containing
            [energy_per_meter, task_execution_time], assuming the surrogate
            returns one value per input row.
        """
        results = self._measure_pref(X)

        # Stacking on the last dim turns (batch_size,) metrics into
        # (batch_size, 2) and scalar metrics into (2,).
        return torch.stack(
            [results['energy_per_meter'], results['task_execution_time']],
            dim=-1,
        )

    def _evaluate_slack_true(self, X: Tensor) -> Tensor:
        """
        Evaluate the constraint slack for input tensor X.

        Reuses the predictions from the most recent surrogate query only when
        that query was made with an identical X; otherwise the surrogate is
        queried again so the slack always corresponds to X.

        Args:
            X: A tensor of shape (batch_size, d) or (d,) where d = 26.

        Returns:
            A tensor of shape (batch_size, 2) or (2,) containing the slack
            [collision, completion]; negative values indicate a violation.
        """
        if self._is_cached_input(X):
            results = self._cached_results
        else:
            results = self._measure_pref(X)

        # collision_risk_score must be <= threshold -> slack = threshold - score
        slack_collision = (
            self.max_collision_risk_score - results['collision_risk_score']
        )
        # task_completion_rate must be >= threshold -> slack = rate - threshold
        slack_completion = (
            results['task_completion_rate'] - self.min_task_completion_rate
        )

        return torch.stack([slack_collision, slack_completion], dim=-1)