import json
import os
import random
from typing import Dict, Generator, List, Literal, Optional
import yaml

import bddl
from numpy.typing import ArrayLike as NumpyArrayLike
import omnigibson as og
from omnigibson import object_states
from omnigibson.tasks import BehaviorTask
from omnigibson.utils.bddl_utils import BEHAVIOR_ACTIVITIES
from omnigibson.sensors import VisionSensor ### 2025.09.30: for ego-view obs.
from PIL import Image
import torch

from .data_utils import (
    CUSTOMIZED_BEHAVIOR_ACTIVITIES, 
    get_customized_definition_filename,
    colorize_bboxes
)
from og_ego_prim.benchmark.base_benchmark import Benchmark
from og_ego_prim.benchmark.evaluator.evaluator import Evaluator
from og_ego_prim.benchmark.tracker.online_tracker import OnlineEvalTracker
from og_ego_prim.primitives import Executor
from og_ego_prim.primitives.executor import RiskyActionError
from omnigibson.action_primitives.action_primitive_set_base import (
    ActionPrimitiveError,
    ActionPrimitiveErrorGroup,
)
from og_ego_prim.primitives.object_states_utils import (
    is_target_object_predicate_with_obj, 
    find_task_related_object,
    get_visible_task_related_objects,
)
from og_ego_prim.utils.constants import CAMERAS, SCENES
from og_ego_prim.utils.types import PoseCoord, StepwisePlan


__all__ = ['ONLINE_BENCHMARKS']


class OnlineBenchmark(Benchmark):
    """Benchmark that executes plans step by step inside a live OmniGibson environment.

    The class wires together the simulator environment, a primitive
    ``Executor``, an ``Evaluator`` and an ``OnlineEvalTracker``, and offers
    helpers for capturing viewer-camera and robot-camera observations.
    """

    # Simulator environment created from ``self.env_config`` in ``__init__``.
    env: og.Environment
    # When true, ``set_viewer`` toggles robot visibility before stepping the simulator.
    ego_view: bool
    # When true, 2D tight bounding boxes of task-related objects are drawn on observations.
    draw_bbox_2d: bool
    # Viewer-camera poses loaded from the camera config for the current room/scene;
    # ``__init__`` leaves this as ``None`` when the config has no matching entry.
    surrounding_poses: List[PoseCoord]

    # Runs action primitives in the environment.
    executor: Executor
    # Evaluates safety / awareness / execution goal conditions.
    evaluator: Evaluator
    # Collects plans, errors and the termination reason of a run.
    tracker: OnlineEvalTracker

    def __init__(
        self,
        task: str,
        scene: str,
        config: Dict,
        debug: bool,
        ego_view: bool,
        draw_bbox_2d: bool,
        use_initial_setup: bool,
        use_self_caption: bool,
        eval_process_safety: bool,
        eval_termination_safety: bool,
        eval_awareness: bool,
        eval_execution: bool,
        robot_ego_view: bool,
    ):
        """Create the environment, executor, evaluator and tracker for one task.

        Args:
            task: Task identifier forwarded to the base ``Benchmark``.
            scene: Scene identifier forwarded to the base ``Benchmark``.
            config: Task configuration; ``config['scene_info']['room']`` and
                ``config['planning_context']`` are read here.
            debug: Debug flag forwarded to the base class and the executor.
            ego_view: See ``set_viewer``.
            draw_bbox_2d: Enable the ``bbox_2d_tight`` modality and overlay.
            use_initial_setup: Stored on the instance; not used in this method.
            use_self_caption: Stored on the instance; not used in this method.
            eval_process_safety: Forwarded to the ``Evaluator``.
            eval_termination_safety: Forwarded to the ``Evaluator``.
            eval_awareness: Forwarded to the ``Evaluator``.
            eval_execution: Forwarded to the ``Evaluator``.
            robot_ego_view: When true, robot camera resolution is configured
                in ``set_viewer``.
        """
        # NOTE(review): the meaning of the trailing positional ``False`` is
        # defined by ``Benchmark.__init__`` and is not visible from this file.
        super().__init__(task, scene, config, debug, False)

        self.env = og.Environment(configs=self.env_config)

        # Make robot invisible (from start to end - so robot is not visible in all observation images)
        if len(self.env.robots) > 0:
            self.env.robots[0].visible = False
            print(f"[INIT] Robot '{self.env.robots[0].name}' set to invisible for observation capture")

        self.ego_view = ego_view
        self.draw_bbox_2d = draw_bbox_2d
        self.use_initial_setup = use_initial_setup
        self.use_self_caption = use_self_caption
        self.robot_ego_view = robot_ego_view

        camera_config_path = os.path.join(CAMERAS, 'camera.json')
        with open(camera_config_path, 'r') as f:
            camera_config = json.load(f)
        room = config['scene_info']['room']

        # Use self.scene_name: the base class may have replaced a ``None``
        # scene with the default scene model.
        self.surrounding_poses = None
        pose_dicts = camera_config.get(f'{room}__{self.scene_name}', None)
        if pose_dicts:
            self.surrounding_poses = [
                (torch.tensor(pose_dict['pos']), torch.tensor(pose_dict['quat']))
                for pose_dict in pose_dicts
            ]

        self.tracker = OnlineEvalTracker()
        self.tracker.task = self.task_name
        self.tracker.scene = self.scene_name

        self.executor = Executor(self.env, primitive_type='ego', debug=debug)
        self.evaluator = Evaluator(
            self.env, config, self.tracker,
            eval_process_safety,
            eval_termination_safety,
            eval_awareness,
            eval_execution
        )

        # ``_get_task_information`` returns None for an empty planning context;
        # look it up once and fall back to ``None`` for both fields instead of
        # failing with "'NoneType' object is not subscriptable".
        task_information = self._get_task_information(config)
        if task_information is None:
            self.task_instruction = None
            self.initial_setup = None
        else:
            self.task_instruction, self.initial_setup = task_information

        self.set_viewer()
        self._add_extra_init_states()

    def get_example_planning(self) -> Generator[str, None, None]:
        for i, plan in enumerate(self._example_planning):
            self.tracker.track_plan(step=i, plan=plan)
            yield plan
            if plan['action'].lower().startswith('done'):
                return

    def set_viewer(self):
        """Prepare the viewer camera and robot vision sensors for capturing observations.

        Steps, in order: optionally toggle robot visibility (``ego_view``),
        try to widen the viewer camera's field of view, optionally enable the
        ``bbox_2d_tight`` modality (``draw_bbox_2d``), and optionally set the
        robot camera resolution (``robot_ego_view``).
        """
        if self.ego_view:
            # NOTE(review): the previous inline comment read "erase robots from
            # the scene", but the flag is set to True here — confirm intent.
            for robot in self.env.robots:
                robot.visible = True
            self.executor._simulator_loop(5)

        # Widen the viewer camera FOV using whichever attribute the camera
        # exposes. Failures are reported and otherwise ignored.
        try:
            camera = og.sim.viewer_camera
            if hasattr(camera, 'vertical_fov'):
                camera.vertical_fov = 90
                print("[VIEWER] Set viewer_camera.vertical_fov = 90")
            elif hasattr(camera, 'set_fov'):
                camera.set_fov(90)
                print("[VIEWER] Set viewer_camera.set_fov(90)")
            elif hasattr(camera, 'fov'):
                camera.fov = 90
                print("[VIEWER] Set viewer_camera.fov = 90")
            elif hasattr(camera, 'focal_length'):
                # FOV = 2 * arctan(sensor_size / (2 * focal_length)), so a
                # shorter focal length widens the view; scale it by 0.6.
                current_focal = camera.focal_length
                new_focal = current_focal * 0.6
                camera.focal_length = new_focal
                print(f"[VIEWER] Set viewer_camera.focal_length = {new_focal} (was {current_focal}) to widen FOV")
            elif hasattr(camera, 'horizontal_aperture'):
                # A larger horizontal aperture also widens the view; double it.
                current_aperture = camera.horizontal_aperture
                new_aperture = current_aperture * 2.0
                camera.horizontal_aperture = new_aperture
                print(f"[VIEWER] Set viewer_camera.horizontal_aperture = {new_aperture} (was {current_aperture}) to widen FOV")
            else:
                print("[VIEWER] Warning: viewer_camera does not have FOV/focal_length attributes")
        except Exception as e:
            print(f"[VIEWER] Error setting FOV: {e}")

        if self.draw_bbox_2d:
            # Bounding boxes are needed on the viewer camera ...
            og.sim.viewer_camera.add_modality('bbox_2d_tight')

            # ... and on every robot-mounted VisionSensor.
            for robot in self.env.robots:
                if not hasattr(robot, 'sensors'):
                    continue
                for sensor in robot.sensors.values():
                    if isinstance(sensor, VisionSensor):
                        sensor.add_modality('bbox_2d_tight')

        # Adjust the VisionSensor observation image size for robot ego view.
        if self.robot_ego_view:
            self._configure_robot_camera_resolution()

    def _add_extra_init_states(self):
        """Mark freezable task objects that sit inside the refrigerator as frozen.

        Does nothing when no task-related ``'refrigerator'`` object is found.
        Otherwise every object in the task's object scope that supports the
        ``Frozen`` state and satisfies the ``Inside`` predicate with respect to
        the refrigerator gets ``Frozen`` set to ``True``, after which the
        simulator loop is run for 5 iterations.
        """
        refrigerator = find_task_related_object(self.env, 'refrigerator')
        if refrigerator is None:
            return

        for scoped_entity in self.env.task.object_scope.values():
            candidate = scoped_entity.wrapped_obj
            # Skip scope entries without a backing object or without states.
            if candidate is None or not hasattr(candidate, 'states'):
                continue
            if object_states.Frozen not in candidate.states:
                continue
            if is_target_object_predicate_with_obj(candidate, refrigerator, object_states.Inside):
                candidate.states[object_states.Frozen].set_value(True)

        self.executor._simulator_loop(5)

    def _configure_robot_camera_resolution(self):
        for robot in self.env.robots:
            if hasattr(robot, 'sensors') and robot.sensors:
                vision_sensors = {name: sensor for name, sensor in robot.sensors.items() if isinstance(sensor, VisionSensor)}

                if vision_sensors:
                    for sensor_name, sensor in vision_sensors.items():
                        sensor.image_height = 224
                        sensor.image_width = 224

    def _get_task_information(self, config: Dict):
        cond_configs = config["planning_context"]
        if not cond_configs:
            return None

        task_instruction = cond_configs['task_instruction']
        initial_setup = cond_configs['initial_setup']

        return task_instruction, initial_setup

    def execute_plan(self, plan: StepwisePlan | str, save_img_dir=None, current_step=None) -> bool:
        if isinstance(plan, str):
            plan: StepwisePlan = dict(action=plan, caution=None)

        # Remove NAVIGATE action filtering - let executor handle it

        self.evaluator.evaluate_process_safety_goal_condition(plan, 'before')

        if self.debug:
            try:
                self.executor.execute_plan(plan['action'], save_img_dir=save_img_dir, caution=plan.get('caution'))
                # Guardrail judged as safe, normal execution - do not save results (use default)
            except RiskyActionError as e:
                # Record error and trigger replan when RiskyActionError occurs
                error_info = {
                    'step': current_step,
                    'action': plan['action'],
                    'err_type': e.__class__.__name__,
                    'msg': str(e),
                    'risk_type': getattr(e, 'risk_type', None),
                    'hazard': getattr(e, 'hazard', None)
                }
                self.tracker.track_error(**error_info)
                # Only save results when guardrail judged as risky and action was not executed
                self._update_plan_results(current_step, "Not Executed!")
                return False  # trigger replan
            except (ActionPrimitiveError, ActionPrimitiveErrorGroup) as e:
                # Record error and include error message in results when ActionPrimitiveError occurs
                error_info = {
                    'step': current_step,
                    'action': plan['action'],
                    'err_type': e.__class__.__name__,
                    'msg': str(e)
                }
                self.tracker.track_error(**error_info)
                # ActionPrimitiveErrorGroup only has "Not Executed!", ActionPrimitiveError includes error message
                if isinstance(e, ActionPrimitiveErrorGroup):
                    self._update_plan_results(current_step, "Not Executed!")
                else:
                    # For ActionPrimitiveError, include message for place on top/place inside related errors
                    error_msg = str(e)
                    if "You should open" in error_msg:
                        self._update_plan_results(current_step, f"Not Executed! {error_msg}")
                    else:
                        self._update_plan_results(current_step, "Not Executed!")
                return False  # trigger replan
            except AttributeError as e:
                # Record error and set results when AttributeError occurs
                error_info = {
                    'step': current_step,
                    'action': plan['action'],
                    'err_type': e.__class__.__name__,
                    'msg': str(e)
                }
                self.tracker.track_error(**error_info)
                self._update_plan_results(current_step, "Not Executed!")
                return False  # trigger replan
            except Exception as e:
                # Handle other errors as-is
                raise e
        else:
            try:
                self.executor.execute_plan(plan['action'], save_img_dir=save_img_dir, caution=plan.get('caution'))
                # Guardrail judged as safe, normal execution - do not save results (use default)
            except RiskyActionError as e:
                # Record error and trigger replan when RiskyActionError occurs
                error_info = {
                    'step': current_step,
                    'action': plan['action'],
                    'err_type': e.__class__.__name__,
                    'msg': str(e),
                    'risk_type': getattr(e, 'risk_type', None),
                    'hazard': getattr(e, 'hazard', None)
                }
                self.tracker.track_error(**error_info)
                # Only save results when guardrail judged as risky and action was not executed
                self._update_plan_results(current_step, "Not Executed!")
                return False  # trigger replan
            except (ActionPrimitiveError, ActionPrimitiveErrorGroup) as e:
                # Record error and include error message in results when ActionPrimitiveError occurs
                error_info = {
                    'step': current_step,
                    'action': plan['action'],
                    'err_type': e.__class__.__name__,
                    'msg': str(e)
                }
                self.tracker.track_error(**error_info)
                # ActionPrimitiveErrorGroup only has "Not Executed!", ActionPrimitiveError includes error message
                if isinstance(e, ActionPrimitiveErrorGroup):
                    self._update_plan_results(current_step, "Not Executed!")
                else:
                    # For ActionPrimitiveError, include message for place on top/place inside related errors
                    error_msg = str(e)
                    if "You should open" in error_msg:
                        self._update_plan_results(current_step, f"Not Executed! {error_msg}")
                    else:
                        self._update_plan_results(current_step, "Not Executed!")
                return False  # trigger replan
            except AttributeError as e:
                # Record error and set results when AttributeError occurs
                self.tracker.track_error(
                    step=current_step,
                    action=plan['action'],
                    err_type=e.__class__.__name__,
                    msg=str(e)
                )
                self._update_plan_results(current_step, "Not Executed!")
                return False  # trigger replan
            except Exception as e:
                # Only record other errors and continue
                self.tracker.track_error(
                    step=current_step,
                    action=plan['action'],
                    err_type=e.__class__.__name__,
                    msg=str(e)
                )

        self.evaluator.evaluate_process_safety_goal_condition(plan, 'after')
        return True

    def _update_plan_results(self, step, results):
        """Save guardrail results to plan for the given step"""
        if step is None:
            return
        # Find the step in tracker.plans and update results
        for plan_entry in self.tracker.plans:
            if plan_entry.get('step') == step:
                if 'plan' in plan_entry:
                    plan_entry['plan']['results'] = results
                break

    def evaluate_awareness(self, awareness: str):
        self.evaluator.evaluate_awareness(
            self.task_instruction,
            self.initial_setup,
            awareness
        )

    def termination_evaluation(self):
        self.evaluator.evaluate_execution_goal_condition()
        self.evaluator.evaluate_non_executed_process_safety_goal_condition()
        self.evaluator.evaluate_termination_safety_goal_condition()
        if self.tracker.termination is None:
            self.tracker.track_termination(
                reason='done'
            )

    def reset_viewer_camera(self, pose: PoseCoord):
        """Move the viewer camera to ``pose`` and let the simulator settle.

        Args:
            pose: ``(position, orientation)`` pair. Each element may be a
                ``torch.Tensor`` or any value ``torch.Tensor(...)`` accepts.
        """
        pos, quat = pose
        # Convert each component on its own so a mixed pose (e.g. tensor
        # position with a list quaternion) is handled too; previously only
        # the first element was inspected.
        if not isinstance(pos, torch.Tensor):
            pos = torch.Tensor(pos)
        if not isinstance(quat, torch.Tensor):
            quat = torch.Tensor(quat)

        og.sim.viewer_camera.set_position_orientation(pos, quat)
        self.executor._simulator_loop(5)

    def _preprocess_obs(self) -> NumpyArrayLike:
        """Grab the viewer camera's RGB frame, optionally with task-object boxes drawn.

        Before reading the frame the FOV is re-checked on a best-effort basis:
        ``vertical_fov`` is reset to 90 when the attribute exists and differs,
        otherwise a ``focal_length`` above 25 is halved (on every call, until
        it drops to 25 or below).

        Returns:
            The RGB observation as a NumPy array. When ``draw_bbox_2d`` is
            set, the result of ``colorize_bboxes`` for the boxes of visible
            task-related objects is returned instead.
        """
        try:
            camera = og.sim.viewer_camera
            if hasattr(camera, 'vertical_fov') and camera.vertical_fov != 90:
                camera.vertical_fov = 90
            elif hasattr(camera, 'focal_length'):
                current_focal = camera.focal_length
                if current_focal > 25:  # threshold above which the view is considered too narrow
                    camera.focal_length = current_focal * 0.5
        except Exception as e:
            # Best effort only, but do not swallow KeyboardInterrupt/SystemExit
            # as the former bare ``except`` did; report like ``set_viewer``.
            print(f"[VIEWER] Error setting FOV: {e}")

        obs, info = og.sim.viewer_camera.get_obs()
        rgb = obs['rgb'].cpu().numpy()
        if not self.draw_bbox_2d:
            return rgb

        bbox_2d_data = obs['bbox_2d_tight']
        bbox_2d_info = info['bbox_2d_tight']
        visible_task_related_objects = get_visible_task_related_objects(self.env)

        # Keep the ids whose label is a substring of a visible task object's name.
        visible_task_related_bbox_2d_id = [
            bbox_2d_id
            for bbox_2d_id, bbox_name in bbox_2d_info.items()
            if any(bbox_name in obj.name for obj in visible_task_related_objects)
        ]
        # The first field of each bbox record is compared against those ids.
        visible_task_related_bbox_2d_data = [
            data for data in bbox_2d_data if data[0] in visible_task_related_bbox_2d_id
        ]
        return colorize_bboxes(visible_task_related_bbox_2d_data, rgb, bbox_2d_info, num_channels=4)

    def get_viewer_obs(
        self, 
        pose: Optional[PoseCoord] = None, 
        save_img: Optional[str] = None
    ) -> NumpyArrayLike:
        # Make robot invisible right before taking photo
        if len(self.env.robots) > 0:
            self.env.robots[0].visible = False

        if pose is not None:
            self.reset_viewer_camera(pose) 

        obs = self._preprocess_obs()
        if save_img is not None:
            if os.path.isdir(save_img):
                save_img = os.path.join(save_img, 'obs.png')
            else:
                os.makedirs(os.path.dirname(save_img), exist_ok=True)

            img = Image.fromarray(obs)
            img.save(save_img)

        return obs

    def get_surrounding_viewer_obs(
        self, save_img: Optional[str] = None
    ) -> Optional[List[NumpyArrayLike]]:
        if self.surrounding_poses is None:
            return None

        if save_img is not None:
            if not os.path.exists(save_img):
                os.makedirs(save_img)
            elif not os.path.isdir(save_img):
                raise ValueError(f'surrounding_obs must be saved in a directory')

        surrounding_obs = []
        for i, pose in enumerate(self.surrounding_poses):
            # Only save obs_0.png, do not save the rest
            save_img_i = None if save_img is None else (os.path.join(save_img, f'obs_{i}.png') if i == 0 else None)
            obs_i = self.get_viewer_obs(pose, save_img_i)
            surrounding_obs.append(obs_i)
        return surrounding_obs

    def save_obs_around_target(self, target_obj, save_img_dir: str):
        """Save obs_a.png, obs_b.png, obs_c.png, obs_d.png around target object"""
        if target_obj is None:
            return
        
        # Make robot invisible
        if len(self.env.robots) > 0:
            self.env.robots[0].visible = False
        
        os.makedirs(save_img_dir, exist_ok=True)
        
        # Get object position
        obj_pos, _ = target_obj.get_position_orientation()
        obj_x, obj_y, obj_z = obj_pos[0].item(), obj_pos[1].item(), obj_pos[2].item()
        
        # Get scene name
        scene_name = self.scene_name
        
        # Define camera direction offsets (same as apply_ref)
        offset_distance = 1.1 if scene_name and ('Wainscott_1_int' in scene_name or 'Beechwood_1_int' in scene_name or 'Benevolence_2_int' in scene_name or 'Beechwood_0_int' in scene_name or 'Beechwood_0_garden' in scene_name or 'Wainscott_0_int' in scene_name) else 2.0
        
        camera_offsets = {
            'a': (+offset_distance,  0.0),  # +X direction
            'b': (-offset_distance,  0.0),  # -X direction
            'c': ( 0.0, +offset_distance),  # +Y direction
            'd': ( 0.0, -offset_distance),  # -Y direction
        }
        
        camera_height = 1.5
        obj_pos_tensor = torch.tensor([obj_x, obj_y, obj_z], dtype=torch.float32)
        
        # Save images from each direction (a, b, c, d)
        for direction in ['a', 'b', 'c', 'd']:
            dx, dy = camera_offsets[direction]
            camera_pos = torch.tensor([
                obj_x + dx,
                obj_y + dy,
                camera_height
            ], dtype=torch.float32)
            
            # Calculate look-at quaternion (from executor's controller)
            camera_quat = self.executor.controller._calculate_look_at_quaternion(camera_pos, obj_pos_tensor)
            
            # Move camera
            og.sim.viewer_camera.set_position_orientation(
                position=camera_pos,
                orientation=camera_quat
            )
            
            # Run simulator loop (stabilize camera movement)
            if hasattr(self.env, 'sim'):
                for _ in range(20):
                    og.sim.step()
            
            # Force rendering
            try:
                for _ in range(3):
                    og.sim.render()
                    og.sim.step()
            except Exception:
                pass
            
            # Capture and save image
            obs = self._preprocess_obs()
            save_path = os.path.join(save_img_dir, f'obs_{direction}.png')
            img = Image.fromarray(obs)
            img.save(save_path)
            print(f"[SAVE_OBS_AROUND_TARGET] Saved image for direction '{direction}': {save_path}")

    ### 2025.09.30: for Robot ego-view obs.
    def get_robot_ego_obs(
        self, 
        save_img: Optional[str] = None,
        sensor_filter: Optional[str] = None,
        exclude_sensors: Optional[List[str]] = None
    ) -> Optional[Dict]:
        # check robot
        robot = self.env.robots[0] if len(self.env.robots) > 0 else None

        if robot is None:
            print('No robot found in environment.')
            return None

        # check VisionSensor
        vision_sensors = {}

        print("ALL robot.sensors:", list(robot.sensors.keys()))
        print("ALL sensor types:", [(k, type(v).__name__) for k, v in robot.sensors.items()])

        # Force enable modality before collecting all VisionSensor candidates
        for name, s in robot.sensors.items():
            if hasattr(s, "add_modality"):
                try:
                    s.add_modality("rgb")  # If not VisionSensor, ignore or exception → try/except
                except Exception:
                    pass

        if hasattr(robot, 'sensors') and robot.sensors:
            vision_sensors = {name: sensor for name, sensor in robot.sensors.items() if isinstance(sensor, VisionSensor)}

            print(f"Available sensors and their types: {[(name, type(sensor).__name__) for name, sensor in vision_sensors.items()]}")

        if not vision_sensors:
            print(f"No VisionSensor found in robot {robot.name}. (type: {type(robot).__name__})")
            if hasattr(robot, 'sensors'):
                print(f"Available sensors: {list(robot.sensors.keys()) if robot.sensors else 'None'}")
            else:
                print("Robot has no sensors.")
            return None

        # sensor filtering
        if sensor_filter:
            filtered_sensors = {name: sensor for name, sensor in vision_sensors.items() if sensor_filter.lower() in name.lower()}

            if filtered_sensors:
                vision_sensors = filtered_sensors
                print(f"Filtered sensors: {list(vision_sensors.keys())}")
            else:
                print(f"No sensors found with filter: {sensor_filter}")
                print(f"Available sensors: {list(vision_sensors.keys())}")

        # sensor exclusion
        if exclude_sensors:
            for exclude_name in exclude_sensors:
                vision_sensors = {name: sensor for name, sensor in vision_sensors.items() if exclude_name.lower() not in name.lower()}
            print(f"Excluded sensors: {exclude_sensors}")

        print(f"Using {len(vision_sensors)} VisionSensors: {list(vision_sensors.keys())}")

        # get ego-view obs
        ego_view_obs = {}

        if save_img is not None:
            if not os.path.exists(save_img):
                os.makedirs(save_img)
            elif not os.path.isdir(save_img):
                raise ValueError(f'ego-view obs must be saved in a directory')

        # get obs from each sensor
        for sensor_name, sensor in vision_sensors.items():
            try:
                obs, info = sensor.get_obs()

                if 'rgb' in obs:
                    rgb_obs = obs['rgb'].cpu().numpy() if hasattr(obs['rgb'], 'cpu') else obs['rgb']
                    print(f"Robot ego camera {sensor_name} resolution: {rgb_obs.shape}")

                    # draw bbox_2d_tight
                    if self.draw_bbox_2d and 'bbox_2d_tight' in obs:
                        bbox_2d_data = obs['bbox_2d_tight']
                        bbox_2d_info = info.get('bbox_2d_tight', {})
                        visible_task_related_objects = get_visible_task_related_objects(self.env)

                        visible_task_related_bbox_2d_id = []
                        for bbox_2d_id, bbox_name in bbox_2d_info.items():
                            for obj in visible_task_related_objects:
                                if bbox_name in obj.name:
                                    visible_task_related_bbox_2d_id.append(bbox_2d_id)
                                    break
                        visible_task_related_bbox_2d_data = [
                            data for data in bbox_2d_data if data[0] in visible_task_related_bbox_2d_id
                        ]

                        if visible_task_related_bbox_2d_data:
                            rgb_obs = colorize_bboxes(visible_task_related_bbox_2d_data, rgb_obs, bbox_2d_info, num_channels=4)

                    ego_view_obs[sensor_name] = rgb_obs

                    # save img
                    if save_img is not None:
                        safe_sensor_name = ''.join(c for c in sensor_name if c.isalnum() or c in ['_', '-', '.'])
                        save_path = os.path.join(save_img, f'ego_{safe_sensor_name}.png')

                        img = Image.fromarray(rgb_obs)
                        img.save(save_path)
                        print(f"Saved ego-view obs of {sensor_name}: {save_path}")
                else:
                    print(f"Sensor {sensor_name} has no rgb obs.")
            except Exception as e:
                print(f"Error getting obs from {sensor_name}: {e}")
                continue

        return ego_view_obs if ego_view_obs else None                

class OnlineBehaviorBenchmark(OnlineBenchmark):
    """Online benchmark whose environment config is built for a BEHAVIOR-style task."""

    def init_env_config(self, task: str, scene: str, config: Dict):
        """Build the environment config for ``task`` in ``scene``.

        Starts from the OmniGibson example config named by
        ``config['_base_config']`` and overrides its task, scene, robot and
        render sections.

        Args:
            task: Task identifier; only used in the unsupported-scene message.
            scene: Scene model to load. When ``None``, the configured default
                scene model is used, or a random supported one if no default
                is set.
            config: Benchmark configuration with ``'_base_config'``,
                ``'task_info'`` and ``'scene_info'`` entries.

        Returns:
            The populated environment config dict.

        Raises:
            AssertionError: If the task name is neither a built-in nor a
                customized activity, or the scene is not listed in
                ``scene_info['scene_models']``.
        """
        base_config_path = os.path.join(og.example_config_path, config['_base_config'])
        with open(base_config_path, 'r') as f:
            # NOTE(review): FullLoader can construct more than plain data; the
            # file comes from OmniGibson's example config directory here, but
            # consider safe_load if that path could ever hold untrusted input.
            env_config = yaml.load(f, Loader=yaml.FullLoader)

        task_info = config['task_info']
        scene_info = config['scene_info']

        # task customization
        task_name = task_info['task_name']
        assert task_name in BEHAVIOR_ACTIVITIES or task_name in CUSTOMIZED_BEHAVIOR_ACTIVITIES
        if task_name not in BEHAVIOR_ACTIVITIES:
            # Register the customized activity and redirect bddl's definition
            # lookup to the customized resolver (module-level patch).
            og.tasks.behavior_task.BEHAVIOR_ACTIVITIES.append(task_name)
            bddl.parsing.get_definition_filename = get_customized_definition_filename

        online_object_sampling = scene_info['online_object_sampling']
        if online_object_sampling:
            task_type = 'CustomBehaviorTask'
        else:
            task_type = task_info['task_type']
        print(f'Using task type: {task_type}')

        activity_definition_id = task_info['activity_definition_id']
        activity_instance_id = task_info['activity_instance_id']
        env_config['task'] = {
            'type': task_type,
            'activity_name': task_name,
            'activity_definition_id': activity_definition_id,
            'activity_instance_id': activity_instance_id,
            'predefined_problem': None,
            'online_object_sampling': online_object_sampling,
        }

        if scene is None:
            # Prefer the configured default; otherwise pick any supported scene.
            scene = scene_info.get('default_scene_model') or random.choice(scene_info['scene_models'])
        assert scene in scene_info['scene_models'], f'task "{task}" is not supported in scene "{scene}"'

        env_config['scene'].update({
            'scene_model': scene,
            'load_task_relevant_only': bool(self.debug),
            'not_load_object_categories': ['ceilings', 'roof']
        })

        # robot customization: set default reset mode to 'tuck' (folded arm)
        # NOTE(review): how the environment consumes the 'robot' and 'render'
        # sections is not visible from this file — confirm they take effect.
        env_config.setdefault('robot', {})['default_reset_mode'] = 'tuck'

        # rendering optimization: fewer samples per pixel and expensive
        # effects switched off, trading quality for speed
        env_config.setdefault('render', {}).update({
            'samples_per_pixel': 4,
            'anti_aliasing': False,
            'reflections': False,
            'global_illumination': False,
            'caustics': False,
        })

        print("[Performance] Rendering optimizations: SPP=4, AA/Reflections/GI/Caustics=OFF")

        # scene customization: use the customized scene json when it exists
        # and objects are not sampled online
        cached_scene_name = BehaviorTask.get_cached_activity_scene_filename(
            scene_model=scene,
            activity_name=task_name,
            activity_definition_id=activity_definition_id,
            activity_instance_id=activity_instance_id,
        )
        scene_file = os.path.join(SCENES, scene, 'json', f'{cached_scene_name}.json')
        if not online_object_sampling and os.path.exists(scene_file):
            env_config['scene']['scene_file'] = scene_file

        return env_config


# Registry of online benchmark classes keyed by name; this is the only name
# the module exports through ``__all__``.
ONLINE_BENCHMARKS = {
    'BehaviorTask': OnlineBehaviorBenchmark,
}