# Agent for functional calls

from pathlib import Path
import pickle

from agentdriver.functional_tools.detection import (
    get_leading_object_detection_info,
    get_surrounding_object_detections_info,
    get_front_object_detections_info,
    get_object_detections_in_range_info,
    get_all_object_detections_info,
    get_leading_object_detection,
    get_surrounding_object_detections,
    get_front_object_detections,
    get_object_detections_in_range,
    get_all_object_detections,
)

from agentdriver.functional_tools.prediction import (
    get_leading_object_future_trajectory_info,
    get_future_trajectories_for_specific_objects_info,
    get_future_trajectories_in_range_info,
    get_future_waypoint_of_specific_objects_at_timestep_info,
    get_all_future_trajectories_info,
    get_leading_object_future_trajectory,
    get_future_trajectories_for_specific_objects,
    get_future_trajectories_in_range,
    get_future_waypoint_of_specific_objects_at_timestep,
    get_all_future_trajectories,
)

from agentdriver.functional_tools.occupancy import (
    get_occupancy_at_locations_for_timestep_info,
    check_occupancy_for_planned_trajectory_info,
    get_occupancy_at_locations_for_timestep,
    check_occupancy_for_planned_trajectory,
)
from agentdriver.functional_tools.map import (
    get_drivable_at_locations_info,
    check_drivable_of_planned_trajectory_info,
    get_lane_category_at_locations_info,
    get_distance_to_shoulder_at_locations_info,
    get_current_shoulder_info,
    get_distance_to_lane_divider_at_locations_info,
    get_current_lane_divider_info,
    get_nearest_pedestrian_crossing_info,
    get_drivable_at_locations,
    check_drivable_of_planned_trajectory,
    get_lane_category_at_locations,
    get_distance_to_shoulder_at_locations,
    get_current_shoulder,
    get_distance_to_lane_divider_at_locations,
    get_current_lane_divider,
    get_nearest_pedestrian_crossing,    
)

from agentdriver.functional_tools.ego_state import get_ego_prompts

class FuncAgent:
    def __init__(self, data_dict) -> None:
        self.data = data_dict
        self.short_trajectory_description = False

        self.detection_func_infos = [
            get_leading_object_detection_info,
            get_object_detections_in_range_info,
            get_surrounding_object_detections_info,
            get_front_object_detections_info,
            get_all_object_detections_info,
        ]
        self.prediction_func_infos = [
            get_leading_object_future_trajectory_info,
            get_future_trajectories_for_specific_objects_info,
            get_future_trajectories_in_range_info,
            get_future_waypoint_of_specific_objects_at_timestep_info,
            get_all_future_trajectories_info,
        ]
        self.occupancy_func_infos = [
            get_occupancy_at_locations_for_timestep_info,
            # check_occupancy_for_planned_trajectory_info,
        ]
        self.map_func_infos = [
            get_drivable_at_locations_info,
            # check_drivable_of_planned_trajectory_info,
            get_lane_category_at_locations_info,
            get_distance_to_shoulder_at_locations_info,
            get_current_shoulder_info,
            get_distance_to_lane_divider_at_locations_info,
            get_current_lane_divider_info,
            get_nearest_pedestrian_crossing_info,
        ]
    

    """Detection functions""" 
    def get_leading_object_detection(self):
        return get_leading_object_detection(self.data)
    
    def get_surrounding_object_detections(self):
        return get_surrounding_object_detections(self.data)
    
    def get_front_object_detections(self):
        return get_front_object_detections(self.data)
    
    def get_object_detections_in_range(self, x_start, x_end, y_start, y_end):
        return get_object_detections_in_range(x_start, x_end, y_start, y_end, self.data)
    
    def get_all_object_detections(self):
        return get_all_object_detections(self.data)
    
    """Prediction functions"""
    def get_leading_object_future_trajectory(self):
        return get_leading_object_future_trajectory(self.data, short=self.short_trajectory_description)
    
    def get_future_trajectories_for_specific_objects(self, object_ids):
        return get_future_trajectories_for_specific_objects(object_ids, self.data, short=self.short_trajectory_description)
    
    def get_future_trajectories_in_range(self, x_start, x_end, y_start, y_end):
        return get_future_trajectories_in_range(x_start, x_end, y_start, y_end, self.data, short=self.short_trajectory_description)
        
    def get_future_waypoint_of_specific_objects_at_timestep(self, object_ids, timestep):
        return get_future_waypoint_of_specific_objects_at_timestep(object_ids, timestep, self.data)
    
    def get_all_future_trajectories(self):
        return get_all_future_trajectories(self.data, short=self.short_trajectory_description)
    
    """Occupancy functions"""
    def get_occupancy_at_locations_for_timestep(self, locations, timestep):
        return get_occupancy_at_locations_for_timestep(locations, timestep, self.data)
    
    def check_occupancy_for_planned_trajectory(self, trajectory):
        return check_occupancy_for_planned_trajectory(trajectory, self.data)

    """Map functions"""
    def get_drivable_at_locations(self, locations):
        return get_drivable_at_locations(locations, self.data)
    
    def check_drivable_of_planned_trajectory(self, trajectory):
        return check_drivable_of_planned_trajectory(trajectory, self.data)
    
    def get_lane_category_at_locations(self, locations, return_score=True):
        return get_lane_category_at_locations(locations, self.data, return_score=return_score)
    
    def get_distance_to_shoulder_at_locations(self, locations):
        return get_distance_to_shoulder_at_locations(locations, self.data)
    
    def get_current_shoulder(self):
        return get_current_shoulder(self.data)
    
    def get_distance_to_lane_divider_at_locations(self, locations):
        return get_distance_to_lane_divider_at_locations(locations, self.data)
    
    def get_current_lane_divider(self):
        return get_current_lane_divider(self.data)
    
    def get_nearest_pedestrian_crossing(self): 
        return get_nearest_pedestrian_crossing(self.data)  
    
    """Ego-state functions"""
    def get_ego_states(self):
        return get_ego_prompts(self.data)