"""
author: Anonymous
"""
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import copy
import gymnasium as gym
import numpy as np
import math

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.lines as lines

import json
from gym.spaces import Dict

from utils.utils import set_seed
from utils.edge_space import Edge, EdgeInstance

from scheduling.location import Location
from scheduling.agent import Agent
from scheduling.task import Task
from scheduling.obstacle import Obstacle

from path_planning.rrg import RRG

from solvers.milp_solver import solve_with_MILP
        
class SchedulingEnvironment(gym.Env):
    def __init__(self, config_file = 'test/sample_problem_5_10_example.json', 
                 generator_config = 'config/config_5_10.json', 
                 mode='dijkstra', verbose=False, 
                 reward_mode="base", merged=False,
                 obstacles = None):
        """Initialize the Scheduling Environment.

        Args:
            config_file (str | None): Problem file to load. If None or the file
                does not exist, a random problem is generated from
                ``generator_config`` instead (and written to ``config_file``
                when it is not None).
            generator_config (str): Generator configuration file, only read when
                a problem has to be generated. Default is 'config/config_5_10.json'.
            mode (str): How travel distances are computed: 'dijkstra' (default),
                'euclidean' or 'manhattan'.
            verbose (bool): Print progress messages while generating.
            reward_mode (str): Stored as ``self.reward_mode``; not read by ``step``.
            merged (bool): If True, only the attributes above are set and
                initialisation stops early; the remaining state is presumably
                filled in afterwards (e.g. via ``copy``/``deep_copy``) — confirm
                with callers.
            obstacles (list | None): Obstacles to use instead of generated ones.
                Only used when a problem is generated; a loaded file keeps its
                own obstacles.
        """
        super().__init__()
        self.mode = mode
        self.reward_mode = reward_mode
        self.verbose = verbose
        self.save_location = config_file
        # Actions recorded by step_warmstart; reset() clears it again.
        self.schedule = []
        if merged:
            return
        if config_file is None or not os.path.exists(config_file):
            with open(generator_config, 'r') as f:
                gen_config = json.load(f)
            # _generate_environment already builds path planning and travel
            # times for the accepted problem (its feasibility was checked with
            # exactly those), so they are not recomputed here.
            self._generate_environment(gen_config, obstacles)
            if self.save_location is not None:
                self.save()
        else:
            with open(config_file, 'r') as f:
                config = json.load(f)
            self._setup_config(config)
            self._get_path_planning(config_file)
            self._generate_travel_times()

        # Observation Space
        self._setup_observation_space()
        # Action Space
        self._setup_action_space()
        self.time_step = 0
        self.num_infeasible = 0
        
    def get_schedule_location(self):
        """Return the path of the schedule file paired with the problem file.

        The extension of ``self.save_location`` is replaced by
        ``_schedule.json`` (``problem.json`` -> ``problem_schedule.json``).
        ``self.save_location`` must be a path, not None.
        """
        stem, _extension = os.path.splitext(self.save_location)
        return f"{stem}_schedule.json"

    def _generate_environment(self, gen_config, obstacles = None):
        """Sample random problems until the solver finds one feasible, then persist it.

        Args:
            gen_config (dict): Generator configuration passed to
                ``_generate_random_config``.
            obstacles (list | None): Obstacles to use instead of the generated ones.

        Side effects:
            Leaves tasks/agents/obstacles, path planning and travel times set up
            for the accepted problem. When ``self.save_location`` is set, the
            solver's schedule is written to ``get_schedule_location()`` and the
            problem is saved to ``self.save_location``; with no save location
            nothing is written.
        """
        if self.verbose:
            print("Generating Environment...")
        feasible = False
        while not feasible:
            config = SchedulingEnvironment._generate_random_config(gen_config, obstacles)
            self._setup_config(config, obstacles)
            self._get_path_planning(self.save_location)
            self._generate_travel_times()
            feasible, schedule = self.solve()
        if self.save_location is None:
            # Both get_schedule_location() and save() need a file path.
            return
        # Save the schedule into a separate file as comma-separated tuples, the
        # format _get_schedule_from_file parses as (task_id, agent_id) pairs.
        with open(self.get_schedule_location(), 'w') as f:
            f.write(", ".join(str(s) for s in schedule))
        self.save()
        
    def _get_schedule_from_file(self):
        """Load the schedule written next to the problem file by ``_generate_environment``.

        The file contains comma-separated ``(task_id, agent_id)`` tuples, e.g.
        ``(0, 1), (2, 0)``.

        Returns:
            list[tuple[int, int]] | None: The parsed pairs, ``[]`` for an empty
            file, or None when no schedule file exists.

        Raises:
            ValueError: If an entry is not a pair of integers.
        """
        schedule_path = self.get_schedule_location()
        if not os.path.exists(schedule_path):
            return None
        with open(schedule_path, 'r') as f:
            # strip() so a trailing newline does not leave a dangling ')' on
            # the last entry after the outer parentheses are sliced off.
            schedule = f.read().strip()
        if not schedule:
            # An empty schedule would otherwise parse as one empty entry and
            # fail to unpack.
            return []
        indexed_schedule = []
        # Drop the outer '(' and ')' and split between consecutive tuples.
        for step in schedule[1:-1].split("), ("):
            task_id, agent_id = step.split(",")
            indexed_schedule.append((int(task_id), int(agent_id)))
        return indexed_schedule
    
    def _setup_config(self, config, obstacles = None):
        """Populate the environment from a problem-description dict.

        Args:
            config (dict): Parsed problem file, in the layout ``save`` writes.
            obstacles (list | None): Forwarded to ``_setup_environment_elements``;
                used instead of ``config['obstacles']`` when given.
        """
        # Problem size
        self.num_agents = config['num_agents']
        self.num_tasks = config['num_tasks']
        self.num_obstacles = config['num_obstacles']

        # Limits and map geometry
        self.deadline = config['max_deadline']
        self.max_speed = config['max_speed']
        map_info = config['map']
        self.width = map_info['width']
        self.height = map_info['height']

        # Tasks, agents and obstacles need the sizes/map above.
        self._setup_environment_elements(config, obstacles)

        self.duration = config['durations']
        # JSON object keys are strings; task ids are ints everywhere else.
        constraints_by_task = {}
        for raw_task_id, wait_times in config['wait_time'].items():
            constraints_by_task[int(raw_task_id)] = wait_times
        self.wait_time_constraints = constraints_by_task
        
    def _setup_environment_elements(self, config, obstacles=None):
        """Create the Task, Agent and Obstacle objects described by ``config``.

        Reads ``self.width``, ``self.height`` and ``self.num_tasks``, so those
        must already be set. ``obstacles``, when not None, is used as-is in
        place of ``config['obstacles']``.
        """
        tasks = []
        for spec in config['tasks']:
            tasks.append(
                Task(spec['id'], spec['x'], spec['y'],
                     start_time=spec['start_time'],
                     end_time=spec['end_time'])
            )
        self.tasks = tasks

        map_size = (self.width, self.height)
        agents = []
        for spec in config['agents']:
            agents.append(
                Agent(spec['id'], spec['x'], spec['y'],
                      speed=spec['speed'],
                      map_size=map_size,
                      num_tasks=self.num_tasks)
            )
        self.agents = agents

        if obstacles is None:
            obstacles = [
                Obstacle(spec['x'], spec['y'], spec['radius'])
                for spec in config['obstacles']
            ]
        self.obstacles = obstacles
        
    def _setup_observation_space(self):
        """Declare the graph-structured observation space produced by ``get_observation``.

        Node sets are agents, tasks, one global state node, and one
        "task assignment" node per (agent, task) pair. Every feature box is
        bounded to [0, 1].
        """
        n_agents = self.num_agents
        n_tasks = self.num_tasks
        n_assignments = n_agents * n_tasks

        def unit_box(shape):
            return gym.spaces.Box(0, 1, shape=shape)

        def link(n_sources, n_targets):
            # Edge type without per-edge features.
            return Edge(
                source_space=gym.spaces.Discrete(n_sources),
                target_space=gym.spaces.Discrete(n_targets),
            )

        def featured_link(n_sources, n_targets, feature_shape):
            # Edge type carrying a [0, 1] feature box per edge.
            return Edge(
                source_space=gym.spaces.Discrete(n_sources),
                target_space=gym.spaces.Discrete(n_targets),
                edge_space=unit_box(feature_shape),
            )

        spaces = {
            # Node features
            'agent': unit_box((n_agents, Agent.get_feature_size())),
            'task': unit_box((n_tasks, Task.get_feature_size())),
            'state': unit_box((1, SchedulingEnvironment.get_state_feature_space())),
            # Agent -> task edges
            'agent_to_task': featured_link(
                n_agents, n_tasks, SchedulingEnvironment.get_agent_task_feature_space()),
            # Task <-> task travel
            'travel': featured_link(
                n_tasks, n_tasks, SchedulingEnvironment.get_task_to_task_travel_feature_space()),
            # Prerequisite -> task wait-time constraints
            'wait_time': featured_link(
                n_tasks, n_tasks, SchedulingEnvironment.get_wait_time_feature_space()),
            'assigned': link(n_agents, n_tasks),
            # Edges into the single state node
            'agent_to_state': link(n_agents, 1),
            'task_to_state': link(n_tasks, 1),
            # (agent, task) assignment nodes for heterogeneous durations
            'task_assignment': unit_box(
                (n_assignments, SchedulingEnvironment.get_task_assignment_feature_space())),
            'agent_to_task_assignment': featured_link(
                n_agents, n_assignments,
                SchedulingEnvironment.get_agent_to_task_assignment_feature_space()),
            'task_to_task_assignment': link(n_tasks, n_assignments),
            'task_assignment_to_task_assignment': featured_link(
                n_assignments, n_assignments,
                SchedulingEnvironment.get_agent_task_task_travel_time_feature_space()),
            'task_to_task_select': link(n_tasks, n_tasks),
        }
        self.observation_space = gym.spaces.Dict(spaces)
    
    def _setup_action_space(self):
        """Declare the action space as a pair of discrete choices: (agent, task).

        NOTE(review): ``step``, ``step_warmstart`` and ``sample_action_space``
        all use ``(task_id, agent_id)`` ordering, the reverse of this space, so
        ``self.action_space.sample()`` yields ``(agent, task)``. Confirm which
        order callers rely on before swapping the components.
        """
        agent_choice = gym.spaces.Discrete(self.num_agents)
        task_choice = gym.spaces.Discrete(self.num_tasks)
        self.action_space = gym.spaces.Tuple((agent_choice, task_choice))
    
    def sample_action_space(self):
        """Sample a random complete schedule.

        Every task appears exactly once, in random order. Each task is paired
        with an agent drawn uniformly with replacement, so agents may repeat.
        Randomness comes from the global ``np.random`` state.

        Returns:
            list[tuple]: ``num_tasks`` ``(task_id, agent_id)`` pairs of numpy
            integers, in the order ``step`` expects.
        """
        task_order = np.random.choice(self.num_tasks, self.num_tasks, replace=False)
        agent_picks = np.random.choice(self.num_agents, self.num_tasks, replace=True)
        return list(zip(task_order, agent_picks))
    
    def _generate_travel_times(self):
        """Precompute the distance tables used by the travel-time helpers.

        Sets:
            initial_travel_distance: ``{(agent_id, task_id): distance}`` from
                each agent's start to each task.
            task_travel_distance: ``{(task_id, other_task_id): distance}``. In
                'euclidean'/'manhattan' mode every ordered pair (including a
                task with itself) is stored; in 'dijkstra' mode only pairs with
                ``task_id < other_task_id``.

        Raises:
            ValueError: If ``self.mode`` is not one of the supported modes.
                Without this the tables would be left unset and fail much later.
        """
        if self.mode in ['euclidean', 'manhattan']:
            self.initial_travel_distance = {(i, j): self.agents[i].location.distance(self.tasks[j].location, mode=self.mode) for i in range(self.num_agents) for j in range(self.num_tasks)}
            
            self.task_travel_distance = {(i, j): self.tasks[i].location.distance(self.tasks[j].location, mode=self.mode) for i in range(self.num_tasks) for j in range(self.num_tasks)}
        elif self.mode in ['dijkstra']:
            self.initial_travel_distance = {(i, j): self.RRG.get_agent_to_task_distance(i, j) for i in range(self.num_agents) for j in range(self.num_tasks)}
            
            self.task_travel_distance = {(i, j): self.RRG.get_task_to_task_distance(i, j) for i in range(self.num_tasks) for j in range(i+1,self.num_tasks)}
        else:
            raise ValueError(
                f"Unsupported travel mode {self.mode!r}; "
                "expected 'euclidean', 'manhattan' or 'dijkstra'"
            )
    
    def load(self, filename):
        """Read a problem description from ``filename`` and apply it with ``_setup_config``.

        Only the configuration is applied. Path planning, travel times and the
        observation/action spaces are not rebuilt here.
        """
        with open(filename, 'r') as handle:
            loaded = json.load(handle)
        self._setup_config(loaded)
        
    def save(self, filename = None):
        """Write the current problem description as indented JSON.

        Args:
            filename (str | None): Destination path. When None,
                ``self.save_location`` is used instead.
        """
        payload = {}
        payload['num_agents'] = self.num_agents
        payload['num_tasks'] = self.num_tasks
        payload['num_obstacles'] = self.num_obstacles
        payload['max_speed'] = self.max_speed
        payload['max_deadline'] = self.deadline
        payload['map'] = {'width': self.width, 'height': self.height}
        # Elements are serialised through their own __dict__() methods.
        payload['tasks'] = [task.__dict__() for task in self.tasks]
        payload['agents'] = [agent.__dict__() for agent in self.agents]
        payload['durations'] = self.duration
        payload['wait_time'] = self.wait_time_constraints
        payload['obstacles'] = [obstacle.__dict__() for obstacle in self.obstacles]

        target = self.save_location if filename is None else filename
        with open(target, 'w') as f:
            json.dump(payload, f, indent=4)
    
    def get_travel_time(self, agent_id, task_id, other_task_id = None):
        """Return how long agent ``agent_id`` needs to travel to ``task_id``.

        With ``other_task_id`` given, this is the travel time between the two
        tasks. Otherwise the trip starts at the agent's most recently assigned
        task, or at the agent's initial position if nothing is assigned yet.
        """
        if other_task_id is not None:
            return self._get_task_task_travel_time(agent_id, task_id, other_task_id)
        assigned = self.agents[agent_id].assigned_tasks
        if not assigned:
            return self._get_agent_task_travel_time(agent_id, task_id)
        return self._get_task_task_travel_time(agent_id, assigned[-1].id, task_id)
        
    def _get_agent_task_travel_time(self, agent_id, task_id):
        """Return the time for an agent to travel from its start position to ``task_id``."""
        distance = self.initial_travel_distance[(agent_id, task_id)]
        speed = self.agents[agent_id].speed
        return distance / speed
     
    def _get_task_task_travel_time(self, agent_id, task_id, other_task_id):
        """Return the time agent ``agent_id`` needs to travel between two tasks.

        Distances are looked up with the smaller id first, matching the
        ``task_id < other_task_id`` keys stored in 'dijkstra' mode. Travelling
        from a task to itself takes no time. That pair is not stored in
        'dijkstra' mode, yet it is queried (e.g. when edge features ask for the
        travel time from an agent's last task to every task, including that
        same task), so it is answered directly instead of raising KeyError.
        """
        if task_id == other_task_id:
            return 0.0
        key = (min(task_id, other_task_id), max(task_id, other_task_id))
        return self.task_travel_distance[key] / self.agents[agent_id].speed
    
    def step(self, action):
        """Assign one task to one agent and advance the episode by one step.

        Args:
            action (tuple): ``(task_id, agent_id)``, the task to schedule and
                the agent that performs it.
        Returns:
            tuple: ``(observation, reward, done, feasible)``. This is a 4-tuple,
            not the Gymnasium 5-tuple, and the last element is the feasibility
            flag of this assignment rather than an info dict.
            ``observation`` is None once every task has been assigned
            (``done``). ``reward`` is ``self.get_raw_score()`` on the final
            step; otherwise it is 1.0 for a feasible assignment and -1.0 for
            an infeasible one.
        """
        task_id, agent_id = action
        if self.mode in ['dijkstra']:
            # Bookkeeping only: the makespan below is computed from the
            # precomputed travel-time tables, not from these reservations.
            self._reserve_path_edges(task_id, agent_id)

        # Check whether the task can be completed in its window / wait-time
        # constraints and get the resulting completion time.
        makespan, _reach_time, feasible = self._update_makespan(task_id, agent_id)
        if not feasible:
            self.num_infeasible += 1
        self.agents[agent_id].step(self.tasks[task_id], makespan)
        self.tasks[task_id].step(self.agents[agent_id], makespan, feasible)
        self.time_step += 1

        done = (self.time_step == self.num_tasks)
        observation = None if done else self.get_observation()

        if done:
            reward = self.get_raw_score()
        elif feasible:
            reward = 1.0
        else:
            reward = -1.0
        return observation, reward, done, feasible

    def _reserve_path_edges(self, task_id, agent_id):
        """Reserve the shared roadmap edges the agent traverses on its way to ``task_id``.

        The agent departs at its current makespan. Each edge is reserved from
        the time the agent reaches its start until the time it reaches its end.

        Returns:
            list: Arrival time at the start of each edge along the path.
        """
        agent = self.agents[agent_id]
        departure_time = agent.current_makespan
        if agent.assigned_tasks:
            # NOTE(review): keyed by task id, consistent with get_travel_time;
            # confirm against RRG.task_to_task_paths.
            path = self.RRG.task_to_task_paths[(agent.assigned_tasks[-1].id, task_id)]
        else:
            path = self.RRG.agent_to_task_paths[(agent_id, task_id)]
        # NOTE(review): ``Edge`` is the name imported from utils.edge_space,
        # which this class otherwise uses as a gym observation space; confirm
        # it also models a roadmap edge with ``distance`` and ``reserve``.
        edges = [Edge(path[i], path[i + 1]) for i in range(len(path) - 1)]
        arrival_times = []
        elapsed = 0.0
        for edge in edges:
            start_time = departure_time + elapsed
            arrival_times.append(start_time)
            elapsed += edge.distance / agent.speed
            if edge in self.RRG.shared_edge_set:
                edge.reserve(agent_id, start_time, departure_time + elapsed)
        return arrival_times

    def step_warmstart(self, action):
        """Apply ``action``; once every task is scheduled, score the full schedule with a warm-started MILP.

        Args:
            action (tuple): ``(task_id, agent_id)``. It is appended to
                ``self.schedule`` and forwarded to ``step``.
        Returns:
            tuple: ``(observation, reward, done, feasible)``.
            - Before the schedule is complete, ``reward`` is 0.0 and ``done`` is False.
            - Once complete, ``feasible`` is the first element of the MILP output.
            - ``reward`` is then the negated last element when feasible, and -100.0 otherwise.
        """
        from solvers.milp_solver import warmstart_MILP
        self.schedule.append(action)
        obs, _, done, feasible = self.step(action)
        if len(self.schedule) != len(self.tasks):
            return obs, 0.0, False, feasible

        solution = warmstart_MILP(self, self.schedule)
        feasible = solution[0]
        reward = solution[-1] * -1.0 if feasible else -100.0
        return obs, reward, done, feasible
 
    def reset(self, *, seed=None, options=None):
        """Reset the environment to its initial, unassigned state.

        Args:
            seed (int | None): Forwarded to ``gym.Env.reset``, which seeds
                ``self.np_random``. Accepted so Gymnasium wrappers and checkers
                can call ``reset(seed=...)``. The global ``np.random`` used by
                ``sample_action_space`` is not reseeded.
            options (dict | None): Unused; accepted for Gymnasium API compatibility.
        Returns:
            tuple: ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        for task in self.tasks:
            task.reset()
        for agent in self.agents:
            agent.reset()
        if self.mode in ['dijkstra']:
            # Clear reservations made while stepping through the previous episode.
            for edge in self.RRG.shared_edge_set:
                edge.reset()

        self.time_step = 0
        self.num_infeasible = 0
        self.schedule = []
        return self.get_observation(), {}

    def copy(self, source):
        """Copy ``source``'s problem into this environment and start a fresh episode.

        Tasks, agents and obstacles are rebuilt with each element's ``copy()``.
        All other attributes (durations, constraints, distance tables, RRG,
        shared paths) are shared by reference with ``source``. ``time_step``
        and ``num_infeasible`` restart at 0.
        """
        for name in ('num_agents', 'num_tasks', 'num_obstacles', 'deadline',
                     'max_speed', 'width', 'height', 'duration',
                     'wait_time_constraints'):
            setattr(self, name, getattr(source, name))

        self.tasks = [task.copy() for task in source.tasks]
        self.agents = [agent.copy() for agent in source.agents]
        self.obstacles = [obstacle.copy() for obstacle in source.obstacles]

        for name in ('initial_travel_distance', 'task_travel_distance',
                     'RRG', 'shared_paths'):
            setattr(self, name, getattr(source, name))

        self._setup_observation_space()
        self._setup_action_space()
        self.time_step = 0
        self.num_infeasible = 0
        
    def deep_copy(self, source):
        """Make this environment an independent copy of ``source``, including episode progress.

        Unlike ``copy``, mutable state is duplicated:
        - ``copy.deepcopy`` for durations, wait-time constraints, distance
          tables, RRG and shared paths;
        - each element's ``deep_copy`` for tasks/agents/obstacles.

        ``time_step`` and ``num_infeasible`` are carried over.
        """
        self.num_agents = source.num_agents
        self.num_tasks = source.num_tasks
        self.num_obstacles = source.num_obstacles
        self.deadline = source.deadline
        self.max_speed = source.max_speed
        self.width = source.width
        self.height = source.height
        # Duplicated so later mutation in either environment cannot leak into the other.
        self.duration = copy.deepcopy(source.duration)
        self.wait_time_constraints = copy.deepcopy(source.wait_time_constraints)

        self.tasks = SchedulingEnvironment._deep_copy_elements(
            getattr(self, 'tasks', None), source.tasks)
        self.agents = SchedulingEnvironment._deep_copy_elements(
            getattr(self, 'agents', None), source.agents)
        self.obstacles = SchedulingEnvironment._deep_copy_elements(
            getattr(self, 'obstacles', None), source.obstacles)

        self.initial_travel_distance = copy.deepcopy(source.initial_travel_distance)
        self.task_travel_distance = copy.deepcopy(source.task_travel_distance)
        self.RRG = copy.deepcopy(source.RRG)
        self.shared_paths = copy.deepcopy(source.shared_paths)

        self._setup_observation_space()
        self._setup_action_space()
        self.time_step = source.time_step
        self.num_infeasible = source.num_infeasible

    @staticmethod
    def _deep_copy_elements(targets, sources):
        """Deep-copy each element of ``sources`` into ``targets`` and return ``targets``.

        If ``targets`` is missing (e.g. an environment built with
        ``merged=True``) or has a different length, it is first rebuilt from
        ``sources`` via each element's ``copy()``. This guarantees every
        source element is copied and no index runs past either list.
        """
        if targets is None or len(targets) != len(sources):
            targets = [element.copy() for element in sources]
        for target, element in zip(targets, sources):
            target.deep_copy(element)
        return targets
    
    def render(self, mode='array'):
        """Render the environment with matplotlib.

        Args:
            mode (str): What to do with the rendering.
                - 'human' shows the figure from ``_render_to_human``.
                - 'file' saves it to ``figures/scheduling_environment.png``,
                  creating ``figures/`` if needed.
                - Any other value (default 'array') returns
                  ``self._render_to_array()``.
        Returns:
            The array rendering in array mode, otherwise None.
        """
        if mode not in ['human', 'file']:
            return self._render_to_array()
        # The first element returned by _render_to_human provides
        # show()/savefig(); named so it does not shadow the pyplot module.
        figure, _ = self._render_to_human()
        if mode == 'human':
            figure.show()
        else:
            # TODO: Replace this with a better file saving mechanism in the future
            # savefig does not create missing directories.
            os.makedirs('figures', exist_ok=True)
            figure.savefig('figures/scheduling_environment.png')
        return None
    
    def get_observation(self):
        """Assemble the graph observation matching ``self.observation_space``.

        Returns:
            dict: Node feature arrays ('agent', 'task', 'state',
            'task_assignment') plus ``EdgeInstance`` objects for every edge type.
        """
        deadline = self.deadline
        agent_features = np.array([agent.get_observation(deadline) for agent in self.agents])
        task_features = np.array([task.get_observation(deadline) for task in self.tasks])
        state_features = np.array([self.get_state()])

        observation = {'agent': agent_features, 'task': task_features, 'state': state_features}
        observation['agent_to_task'] = self._get_agent_task_edge_features()
        observation['travel'] = self._get_task_to_task_travel_distance_edge_features()
        observation['wait_time'] = self._get_wait_time_edge_features()
        observation['assigned'] = self._get_assigned_edge_features()
        # Every agent and every task links to the single state node (index 0).
        observation['agent_to_state'] = EdgeInstance(
            np.array([agent.id for agent in self.agents]), np.zeros(self.num_agents), None)
        observation['task_to_state'] = EdgeInstance(
            np.array([task.id for task in self.tasks]), np.zeros(self.num_tasks), None)
        # (agent, task) assignment nodes for heterogeneous durations
        observation['task_assignment'] = self._get_task_assignment()
        observation['agent_to_task_assignment'] = self._get_agent_to_task_assignment()
        observation['task_to_task_assignment'] = self._get_task_to_task_assignment()
        observation['task_assignment_to_task_assignment'] = self._get_task_assignment_to_task_assignment()
        observation['task_to_task_select'] = self._get_task_to_task_select()
        return observation
    
    def _get_duration(self, agent_id, task_id):
        """Return how long ``agent_id`` takes to perform ``task_id``.

        ``self.duration`` is a flat, agent-major list of ``{'duration': ...}``
        entries.
        """
        flat_index = agent_id * self.num_tasks + task_id
        return self.duration[flat_index]['duration']
    
    def _update_makespan(self, task_id, agent_id):
        """Compute when ``agent_id`` would finish ``task_id`` and record it on the task.

        The task starts at the latest of:
        - the agent's arrival (current makespan + travel time);
        - the task's ``start_time``;
        - for each wait-time constraint on the task, the prerequisite's
          expected completion time plus the required wait.

        Args:
            task_id (int): The task id
            agent_id (int): The agent id
        Returns:
            tuple: ``(end_time, reach_time, feasible)``.
            - ``end_time`` is the completion time, or ``self.deadline`` if the
              task's ``end_time`` is exceeded.
            - ``reach_time`` is when the agent arrives.
            - ``feasible`` is False if the time window is missed or a
              prerequisite is unassigned or has no expected completion time.

        Side effects:
            Sets ``self.tasks[task_id].expected_completion_time`` to ``end_time``.
        """
        agent = self.agents[agent_id]
        task = self.tasks[task_id]
        duration = self._get_duration(agent_id, task_id)
        travel_time = self.get_travel_time(agent_id, task_id)
        reach_time = agent.current_makespan + travel_time

        earliest_start_times = [reach_time, task.start_time]
        feasible = True
        for constraint in self.wait_time_constraints.get(task_id, []):
            prerequisite = self.tasks[constraint['prerequisite']]
            if prerequisite.assigned_agent is None:
                # Prerequisite not scheduled yet: its completion time is unknown.
                feasible = False
            elif prerequisite.expected_completion_time is None:
                # Scheduled but with no completion time recorded, so the wait
                # constraint cannot be placed (adding the wait to None would crash).
                feasible = False
            else:
                earliest_start_times.append(
                    prerequisite.expected_completion_time + constraint['wait_time'])

        start_time = max(earliest_start_times)
        end_time = start_time + duration
        if end_time > task.end_time:
            # Misses the task's window: infeasible, completion pinned to the deadline.
            feasible = False
            end_time = self.deadline
        task.expected_completion_time = end_time
        return end_time, reach_time, feasible
    
    # Edge Feature Functions
    def _get_agent_task_edge_features(self):
        """Build agent -> task edges, one per (agent, task) pair in agent-major order.

        Each edge's features are ``[duration, travel_time]``, both divided by
        ``self.deadline``.
        """
        pairs = [
            (agent_id, task_id)
            for agent_id in range(self.num_agents)
            for task_id in range(self.num_tasks)
        ]
        sources = np.array([agent_id for agent_id, _ in pairs])
        targets = np.array([task_id for _, task_id in pairs])
        features = np.array([
            [
                self._get_duration(agent_id, task_id) / self.deadline,
                self.get_travel_time(agent_id, task_id) / self.deadline,
            ]
            for agent_id, task_id in pairs
        ])
        return EdgeInstance(sources, targets, features)
    
    @staticmethod
    def get_agent_to_task_feature_size():
        """Return the number of agent -> task edge features (one per named feature)."""
        return len(SchedulingEnvironment.get_agent_to_task_feature_space())
    
    @staticmethod
    def get_agent_to_task_feature_space():
        """Return the names of the agent -> task edge features, in column order.

        These appear to label the two columns built by
        ``_get_agent_task_edge_features``; a new list is returned on each call.
        """
        names = ("Duration", "Travel Time")
        return list(names)
    
    def _get_task_to_task_travel_distance_edge_features(self):
        """Build bidirectional task <-> task edges weighted by travel distance.

        For every pair ``i < j``, one edge goes each way. Each edge carries a
        single feature: the stored distance divided by the map diagonal. All
        forward edges come first, then the reversed copies.
        """
        n_tasks = self.num_tasks
        diagonal = math.sqrt(self.width ** 2 + self.height ** 2)
        pairs = [(i, j) for i in range(n_tasks) for j in range(i + 1, n_tasks)]
        forward_sources = np.array([i for i, _ in pairs])
        forward_targets = np.array([j for _, j in pairs])
        features = np.array([
            [self.task_travel_distance[(i, j)] / diagonal] for i, j in pairs
        ])
        return EdgeInstance(
            np.concatenate([forward_sources, forward_targets]),
            np.concatenate([forward_targets, forward_sources]),
            np.concatenate([features, features]),
        )
    
    def _get_wait_time_edge_features(self):
        """Build prerequisite -> task edges, one per wait-time constraint.

        ``self.wait_time_constraints`` maps ``task_id`` to a list of
        ``{'prerequisite': id, 'wait_time': t}`` dicts. Each edge's single
        feature is ``wait_time / self.deadline``.
        """
        sources, targets, features = [], [], []
        for task_id, constraints in self.wait_time_constraints.items():
            for constraint in constraints:
                sources.append(constraint['prerequisite'])
                targets.append(task_id)
                features.append([constraint['wait_time'] / self.deadline])
        return EdgeInstance(np.array(sources), np.array(targets), np.array(features))
    
    def _get_assigned_edge_features(self):
        """Build featureless agent -> task edges for every task already assigned to an agent."""
        owners, assigned = [], []
        for agent_id, agent in enumerate(self.agents):
            for task in agent.assigned_tasks:
                owners.append(agent_id)
                assigned.append(task.id)
        return EdgeInstance(np.array(owners), np.array(assigned), None)
    
    def _get_agent_task_task_travel_time(self):
        """Build assignment -> assignment edges carrying per-agent task-to-task travel time.

        For each agent and each task pair ``i < j``, one edge links the
        agent's assignment node for ``i`` to the node for ``j``. The single
        feature is the travel time divided by ``self.deadline``.
        """
        n_tasks = len(self.tasks)
        sources, targets, features = [], [], []
        for agent_id in range(len(self.agents)):
            for task_id in range(n_tasks):
                for other_task_id in range(task_id + 1, n_tasks):
                    sources.append(self.get_task_assignment_id(agent_id, task_id))
                    targets.append(self.get_task_assignment_id(agent_id, other_task_id))
                    travel = self.get_travel_time(agent_id, task_id, other_task_id)
                    features.append([travel / self.deadline])
        return EdgeInstance(np.array(sources), np.array(targets), np.array(features))

    def _get_task_assignment(self):
        return np.array([
            [
                self._get_duration(agent_id, task_id) / self.deadline
            ] 
            for agent_id, _ in enumerate(self.agents) 
            for task_id, _ in enumerate(self.tasks)
        ])
        
    def _get_agent_to_task_assignment(self):
        """Build edges from every agent node to each of its task-assignment nodes.

        The target index is ``agent_id * num_tasks + task_id``. The single edge
        feature is ``get_travel_time(agent_id, task_id) / deadline``.
        """
        sources, targets, features = [], [], []
        for agent_idx in range(len(self.agents)):
            for task_idx in range(len(self.tasks)):
                sources.append(agent_idx)
                targets.append(agent_idx * self.num_tasks + task_idx)
                features.append([self.get_travel_time(agent_idx, task_idx) / self.deadline])
        return EdgeInstance(np.array(sources), np.array(targets), np.array(features))
    
    def _get_task_to_task_assignment(self):
        """Build featureless edges from each task node to its assignment nodes.

        Task ``t`` is connected to ``agent_id * num_tasks + t`` for every agent.
        """
        pairs = [
            (task_idx, agent_idx * self.num_tasks + task_idx)
            for agent_idx in range(len(self.agents))
            for task_idx in range(len(self.tasks))
        ]
        task_nodes = np.array([task for task, _ in pairs])
        assignment_nodes = np.array([assignment for _, assignment in pairs])
        return EdgeInstance(task_nodes, assignment_nodes, None)
    
    def _get_task_assignment_to_task_assignment(self):
        """Build directed edges between distinct assignment nodes of one agent.

        For every agent and every ordered pair of different tasks ``(i, j)``,
        the edge goes from ``agent * num_tasks + i`` to ``agent * num_tasks + j``
        and carries ``[get_travel_time(agent, i, j) / deadline]``.
        """
        sources, targets, features = [], [], []
        for agent_idx in range(len(self.agents)):
            for origin in range(len(self.tasks)):
                for dest in range(len(self.tasks)):
                    if origin == dest:
                        continue
                    sources.append(agent_idx * self.num_tasks + origin)
                    targets.append(agent_idx * self.num_tasks + dest)
                    features.append([self.get_travel_time(agent_idx, origin, dest) / self.deadline])
        return EdgeInstance(np.array(sources), np.array(targets), np.array(features))
        
    def _get_shared_edge_features(self):
        """Return the shared edge features for the agents.

        The edge features are the ``get_feature()`` vectors of every edge in
        ``self.RRG.shared_edge_set``. The target nodes are the positions of
        those edges in that list. The source nodes enumerate each agent id
        once per task.

        Returns:
            EdgeInstance: (source_nodes, target_nodes, edge_features).
        """
        # A former per-(agent, task) pre-pass was removed. Its results were
        # always overwritten below, and it evaluated ``not <ndarray>`` on the
        # per-pair feature arrays. That raises ValueError ("truth value of an
        # array ... is ambiguous") for any array with more than one element.
        # NOTE(review): ``source_nodes`` has ``num_agents * num_tasks`` entries
        # while ``target_nodes`` has ``len(shared_edge_set)``; confirm the
        # intended pairing before relying on this EdgeInstance.
        shared_edges = self.RRG.shared_edge_set
        source_nodes = np.array([agent_id for agent_id, _ in enumerate(self.agents) for _ in self.tasks])
        target_nodes = np.array([i for i, _ in enumerate(shared_edges)])
        edge_features = np.array([edge.get_feature() for edge in shared_edges])
        return EdgeInstance(source_nodes, target_nodes, edge_features)
        
    def _get_task_to_task_select(self):
        """Build featureless edges from each unassigned task to a dense index.

        Tasks whose ``assigned_agent`` is ``None`` are listed in order; the
        k-th such task is connected to target ``k``.
        """
        unassigned_ids = [task.id for task in self.tasks if task.assigned_agent is None]
        source_nodes = np.array(unassigned_ids)
        target_nodes = np.array(list(range(len(unassigned_ids))))
        return EdgeInstance(source_nodes, target_nodes, None)
    
    def get_task_assignment_id(self, agent_id, task_id):
        """Map an (agent_id, task_id) pair to its flat assignment-node index.

        Indices are agent-major: ``agent_id * num_tasks + task_id``.
        """
        offset = agent_id * self.num_tasks
        return offset + task_id
    
    def get_task_agent_id(self, task_assignment_id):
        """Invert ``get_task_assignment_id``.

        Returns:
            tuple: ``(agent_id, task_id)``, the floor quotient and remainder of
            the flat index divided by ``num_tasks``.
        """
        agent_id, task_id = divmod(task_assignment_id, self.num_tasks)
        return agent_id, task_id
    
    # Feature Space Functions
    @staticmethod
    def get_agent_feature_space():
        """Return the feature space for the agents
        (x, y, speed, current_makespan)"""
        return (4,)
    
    @staticmethod
    def get_task_feature_space():
        """Return the feature space for the tasks
        (x, y, start_time, end_time, allocated)"""
        return (4,)
    
    
    @staticmethod
    def get_task_assignment_feature_space():
        """Return the feature space for the agent-task edges
        (duration, travel_time)"""
        return 1
    
    @staticmethod
    def get_agent_to_task_assignment_feature_space():
        """Return the feature space for the agent-task edges
        (duration, travel_time)"""
        return (1,)
    
    @staticmethod
    def get_wait_time_feature_space():
        """Return the feature space for the wait time between the tasks
        (wait_time)"""
        return (1,)
    
    @staticmethod
    def get_agent_task_task_travel_time_feature_space():
        """Return the feature space for the travel time between the tasks for each agent
        (travel_time)"""
        return (1,)
    
    @staticmethod
    def get_task_to_task_travel_feature_space():
        """Return the feature space for the travel distance between the tasks
        (travel_distance)"""
        return (1,)
    
    @staticmethod
    def get_shared_edge_features_space():
        """Return the feature space for the shared edges in the RRG
        (start_time, duration, end_time)"""
        return (3,)
        
    def get_metadata(self):
        """Summarise problem size, speed limit, deadline and map bounds as a dict."""
        metadata = {
            "num_agents": self.num_agents,
            "num_tasks": self.num_tasks,
            "num_obstacles": self.num_obstacles,
            "max_speed": self.max_speed,
            "max_deadline": self.deadline,
        }
        metadata["map"] = {"width": self.width, "height": self.height}
        return metadata
    
    def get_state(self):
        """Return the global state vector of the environment.

        Returns:
            np.ndarray: two values,
            ``[assigned tasks / total tasks, max makespan / deadline]``, in the
            order named by ``get_state_features``. The assigned fraction is
            0.0 when the environment has no tasks, instead of dividing by zero.
        """
        num_tasks = len(self.tasks)
        num_assigned = sum(len(agent.assigned_tasks) for agent in self.agents)
        assigned_fraction = num_assigned / num_tasks if num_tasks else 0.0
        return np.array([
            assigned_fraction,
            self._get_max_makespan() / self.deadline,
        ])
    
    @staticmethod
    def get_state_features():
        return [
            "Assigned Task Percentage",
            "Maximum Makespan"
        ]
    
    @staticmethod
    def get_state_feature_space():
        return 2
    
    @staticmethod
    def get_agent_task_feature_space():
        return (2,)
    
    def _generate_feature_space(self, min=0, max=1, dim=2):
        """Create a Gymnasium space covering ``[min, max]``.

        Args:
            min: Lower bound of the space.
            max: Upper bound of the space (inclusive in the discrete case).
            dim: ``1`` yields a ``Discrete`` space over the integers
                ``min..max``; any other value yields a ``Box`` of shape ``(dim,)``.

        Returns:
            gym.spaces.Discrete | gym.spaces.Box
        """
        # ``min``/``max`` shadow the builtins but are part of the signature.
        if dim == 1:
            # Discrete(n, start=s) covers {s, ..., s + n - 1}. It has no
            # ``stop`` keyword; passing one raised TypeError.
            return gym.spaces.Discrete(max - min + 1, start=min)
        return gym.spaces.Box(low=min, high=max, shape=(dim,))
        
    def _get_max_makespan(self):
        makespans = [agent.current_makespan for agent in self.agents if agent.assigned_tasks and agent.current_makespan < self.deadline]
        if not makespans:
            return 0.0
        return max(makespans)
    
    def get_raw_score(self):
        """Score a (partial) schedule: completed tasks plus a makespan bonus.

        Pending tasks are the infeasible ones plus those not yet reached
        (``num_tasks - time_step``). The score is
        ``num_tasks - pending + (1 - max_makespan / deadline)`` as a float.
        When nothing is pending, this equals the full-completion score.
        """
        bonus = 1.0 - self._get_max_makespan() / self.deadline
        pending = self.num_infeasible + (self.num_tasks - self.time_step)
        return float(self.num_tasks - pending + bonus)
    
    # Path Planning Functions 
    def _get_path_planning(self, filename = None):
        """Generate the RRG and the shared paths between the plans
        Path Plan Name is the same as the filename, with the extension "_path.json"
        Args:
            filename (str): The filename to save the path plan. If None, then the path plan is generated.
        
        """
        if self.mode in ['dijkstra', 'rrg']:
            if filename:
                base_filename = os.path.splitext(filename)[0]
                path_filename = base_filename + '_path.json'
                self.RRG = RRG(self, path_filename) if os.path.exists(path_filename) else RRG(self)
            else:
                self.RRG = RRG(self)
            if not self.RRG.fully_connected:
                raise Exception("RRG is not fully connected, recommend generating a new problem.")
            self.shared_paths = self.RRG.find_shared_edges()
            all_shared_edges = {edge for path in self.shared_paths for edge in path}
        elif self.mode in ['euclidean', 'manhattan']:
            self.RRG = None
            self.shared_paths = None
            all_shared_edges = None
        
    def _get_path_to_task(self, agent_id, task_id):
        """Get the path to the task from the agent"""
        if self.agents[agent_id].assigned_tasks:
            return self.RRG.task_to_task_paths[(self.agents[agent_id].assigned_tasks[-1], task_id)]
        else:
            return self.RRG.agent_to_task_paths[(agent_id, task_id)]
        
    def _get_shared_edge_indices(self, agent_id, task_id):
        path = self._get_path_to_task(agent_id, task_id)
        return [self.RRG.shared_edge_set.index(edge) for edge in path if edge in self.RRG.shared_edge_set]
    
    def _get_shared_edge_features_of_agent_task(self, agent_id, task_id):
        indices = self._get_shared_edge_indices(agent_id, task_id)
        if not indices:
            return np.array([])
        
        
        return np.array([self.RRG.shared_edge_set[i].get_feature() for i in indices])
    
    def _render_to_human(self, fig=None, ax=None, legends_lines=None, labels=None):
        """Render the environment with Matplotlib.

        Agents are drawn as blue triangles, tasks as red 2x2 squares centred on
        the task, and obstacles as grey circles of their radius.

        Args:
            fig: Kept for backward compatibility; drawing only uses ``ax``.
            ax: Axes to draw on. A new figure/axes pair is created when ``ax``
                is ``None``.
            legends_lines (list): Extra legend handles, appended after the
                Agent/Task/Obstacle entries.
            labels (list): Labels for ``legends_lines``, in the same order.

        Returns:
            tuple: ``(matplotlib.pyplot, ax)``.
        """
        # ``None`` defaults avoid one mutable list being shared across calls.
        legends_lines = [] if legends_lines is None else legends_lines
        labels = [] if labels is None else labels
        # Create axes whenever none were supplied. Previously, passing a figure
        # without axes left ``ax`` as None and crashed on ``add_patch``.
        if ax is None:
            _, ax = plt.subplots()

        # Draw the agents as blue triangles
        for agent in self.agents:
            triangle = patches.RegularPolygon((agent.location.x, agent.location.y), numVertices=3, radius=1, orientation=0, color='blue')
            ax.add_patch(triangle)

        # Draw the tasks as red squares centred on the task location
        for task in self.tasks:
            square = patches.Rectangle((task.location.x - 1, task.location.y - 1), 2, 2, color='red')
            ax.add_patch(square)

        # Draw the obstacles as grey circles
        for obstacle in self.obstacles:
            circle = patches.Circle((obstacle.location.x, obstacle.location.y), obstacle.radius, color='grey')
            ax.add_patch(circle)

        ax.set_xlim(0, self.width)
        ax.set_ylim(0, self.height)
        ax.set_aspect('equal')

        legends = [
            lines.Line2D([0], [0], marker='^', color='blue', lw=0, label='Agent'),
            lines.Line2D([0], [0], marker='s', color='red', lw=0, label='Task'),
            lines.Line2D([0], [0], marker='o', color='grey', lw=0, label='Obstacle'),
        ]
        legends.extend(legends_lines)
        # Place the legend outside the plot, to the right of the axes.
        ax.legend(legends, ['Agent', 'Task', 'Obstacle'] + list(labels), loc='upper left', bbox_to_anchor=(1.04, 1))

        return plt, ax
    
    def _render_to_array(self):
        """Rasterise the environment into a ``(width, height)`` grid.

        Cell codes: 0 = empty, 1 = agent, 2 = task, 3 = obstacle. Cells are
        indexed as ``grid[x, y]``. Later layers overwrite earlier ones, so an
        obstacle covering an agent or task cell reads 3.

        Returns:
            np.ndarray: float array of shape ``(self.width, self.height)``.
        """
        grid = np.zeros((self.width, self.height))
        for agent in self.agents:
            grid[agent.location.x, agent.location.y] = 1
        for task in self.tasks:
            grid[task.location.x, task.location.y] = 2
        for obstacle in self.obstacles:
            cx, cy, radius = obstacle.location.x, obstacle.location.y, obstacle.radius
            grid[cx, cy] = 3
            # Scan the disc's bounding square, clipped to the map, so that
            # obstacles near an edge neither raise IndexError (past the far
            # edge) nor wrap around through negative indices. ``+ 1`` keeps
            # the boundary cell that the ``<=`` test below accepts.
            for i in range(max(0, cx - radius), min(self.width, cx + radius + 1)):
                for j in range(max(0, cy - radius), min(self.height, cy + radius + 1)):
                    if obstacle.location.distance(Location(i, j)) <= radius:
                        grid[i, j] = 3
        return grid
        
    # Static Methods for Problem Generator
    @staticmethod
    def _generate_random_config(generator_config, obstacles = None):
        config = {}        
        if generator_config['num_agents'][0] == generator_config['num_agents'][1]:
            config['num_agents'] = generator_config['num_agents'][0]
        else:
            config['num_agents'] = np.random.randint(generator_config['num_agents'][0], generator_config['num_agents'][1])[0]
            
        if generator_config['num_tasks'][0] == generator_config['num_tasks'][1]:
            config['num_tasks'] = generator_config['num_tasks'][0]
        else:
            config['num_tasks'] = np.random.randint(generator_config['num_tasks'][0], generator_config['num_tasks'][1])[0]
        config['max_speed'] = generator_config['agent_config']['speed']['max']
        
        config['map'] = {
            'width': generator_config['map']['width'],
            'height': generator_config['map']['height']
        }
        
        config['agents'] = [
            {
                'id' : id,
                'x' : np.random.randint(0, config['map']['width']),
                'y' : np.random.randint(0, config['map']['height']),
                'speed' : np.random.randint(generator_config['agent_config']['speed']['min'], generator_config['agent_config']['speed']['max'])
            } for id in range(config['num_agents'])
        ]
        
        # max deadline = (max travel time + max task duration) * num_tasks
        config['max_deadline'] = (math.sqrt(config['map']['width'] ** 2 + config['map']['height'] ** 2) / generator_config['agent_config']['speed']['min'] + generator_config['task_config']['duration']['max']) * config['num_tasks']
        
        config['tasks'] = []
        for id in range(config['num_tasks']):
            
            task = {
                'id' : id,
                'x' : np.random.randint(0, config['map']['width']),
                'y' : np.random.randint(0, config['map']['height'])
            }
            # time window percentage is between min_percentage and max_percentage
            time_window_percentage = np.random.random() * (generator_config['task_config']['time_window_percentage'][1] - generator_config['task_config']['time_window_percentage'][0]) + generator_config['task_config']['time_window_percentage'][0]
            # start time is between 0 and (1 - time_window_percentage) * max_deadline
            if time_window_percentage == 1.0:
                task['start_time'] = 0.0
                task['end_time'] = config['max_deadline']
            else:
                task['start_time'] = np.random.randint(0, config['max_deadline'] * (1.0 - time_window_percentage))
                # time window is between time_window_percentage * max_deadline and max_deadline
                time_window = int(config['max_deadline'] * time_window_percentage)
                
                task['end_time'] = task['start_time'] + time_window
            config['tasks'].append(task)    

        # Heterogenous Durations for Task Agent Pairing
        
        config['durations'] = [
                {
                    'agent': i,
                    'task': j,
                    'duration': np.random.randint(generator_config['task_config']['duration']['min'], generator_config['task_config']['duration']['max'])
                } for i in range(config['num_agents']) for j in range(config['num_tasks'])
            ]
        
        config['wait_time_percentage'] = np.random.rand() * (generator_config['task_config']['wait_time_percentage'][1] - generator_config['task_config']['wait_time_percentage'][0]) + generator_config['task_config']['wait_time_percentage'][0]
        wait_time_percentage = np.random.rand() * (generator_config['task_config']['wait_time_percentage'][1] - generator_config['task_config']['wait_time_percentage'][0]) + generator_config['task_config']['wait_time_percentage'][0]
        config['num_wait_time'] = config['num_tasks'] * wait_time_percentage
        
        
        if generator_config['task_config']['wait_time_percentage'][0] == generator_config['task_config']['wait_time_percentage'][1]:
            wait_time_percentage = generator_config['task_config']['wait_time_percentage'][0]
        else:
            wait_time_percentage = np.random.rand() * (generator_config['task_config']['wait_time_percentage'][1] - generator_config['task_config']['wait_time_percentage'][0]) + generator_config['task_config']['wait_time_percentage'][0]
        
        
        config['wait_time'] = {}
        for _ in range(int(wait_time_percentage * config['num_tasks'])):
            task_id = int(np.random.randint(0, config['num_tasks']))
            
            prerequisite_task = int(np.random.choice([j for j in range(config['num_tasks']) if j != task_id]))
            
            if task_id not in config['wait_time']:
                config['wait_time'][task_id] = []
            
            config['wait_time'][task_id].append(
                {
                    'prerequisite': prerequisite_task,
                    'wait_time': np.random.randint(generator_config['task_config']['wait_time_duration'][0], generator_config['task_config']['wait_time_duration'][1])
                }
            )
            
        if obstacles is None:
            if generator_config['obstacle_config']['number'][0] == generator_config['obstacle_config']['number'][1]:
                config['num_obstacles'] = generator_config['obstacle_config']['number'][0]
            else:
                config['num_obstacles'] = np.random.randint(generator_config['obstacle_config']['number'][0], generator_config['obstacle_config']['number'][1])
                
            config['obstacles'] = SchedulingEnvironment._generate_obstacles(config, generator_config)
        else:
            config['num_obstacles'] = len(obstacles)
            # config['obstacles'] = obstacles
        return config
     
    @staticmethod
    def _generate_obstacles(config, generator_config):
        """Rejection-sample obstacles until ``config['num_obstacles']`` are placed.

        Each candidate gets a random integer centre inside the map and a
        radius drawn from ``generator_config['obstacle_config']['radius']``
        (high exclusive). A candidate whose centre is closer than its radius to
        any agent or task in ``config`` is discarded and redrawn.

        Args:
            config (dict): Partially built configuration with ``map``,
                ``agents``, ``tasks`` and ``num_obstacles``.
            generator_config (dict): Generator settings.

        Returns:
            list: One entry per placed obstacle, as returned by ``Obstacle.__dict__()``.
        """
        placed = []
        while len(placed) < config['num_obstacles']:
            radius_bounds = generator_config['obstacle_config']['radius']
            x = np.random.randint(0, config['map']['width'])
            y = np.random.randint(0, config['map']['height'])
            r = np.random.randint(radius_bounds[0], radius_bounds[1])
            candidate = Obstacle(x, y, r)
            occupied = (
                (item['x'], item['y'])
                for group in ('agents', 'tasks')
                for item in config[group]
            )
            blocked = any(
                candidate.location.distance(Location(px, py)) < candidate.radius
                for px, py in occupied
            )
            if blocked:
                continue
            # NOTE(review): ``__dict__`` is called here, so ``Obstacle``
            # presumably defines it as a method returning a plain dict.
            placed.append(candidate.__dict__())
        return placed
    
    def solve(self):
        """Reset the environment and solve it with the MILP solver.

        Returns:
            tuple: ``(feasible, schedule)``, taken from the first and fourth
            values returned by ``solve_with_MILP``.
        """
        if self.verbose:
            print("Solving the Environment using MILP Solver")
        self.reset()
        (feasible, _unused_a, _unused_b, schedule, _unused_c) = solve_with_MILP(self)
        summary = f"...Solved the Environment with Feasibility: {feasible}, Schedule: {schedule}..."
        print(summary)
        return feasible, schedule
    
    @staticmethod
    def combine_problems(save_paths, target_path, mode):
        """Merge several saved problems into one larger problem and save it.

        Tasks and agents of every sub-problem are appended with fresh,
        consecutive ids. Wait-time constraints, durations and the stored
        schedules are remapped to those ids. Obstacles come from the first
        sub-problem that has any; deadline, speed and map size come from the
        last one. Agent-task pairs from different sub-problems get a random
        duration in [10, 100).

        Args:
            save_paths (list[str]): Problem files to merge, in order.
            target_path (str): File the combined problem is saved to.
            mode (str): Travel-distance mode used for every environment.

        Returns:
            SchedulingEnvironment: The combined environment.

        Raises:
            ValueError: If ``save_paths`` is empty.
        """
        if not save_paths:
            raise ValueError("combine_problems needs at least one problem to merge")
        combined_env = SchedulingEnvironment(target_path, mode=mode, merged=True)
        combined_env.tasks = []
        combined_env.agents = []
        combined_env.obstacles = []
        combined_env.duration = []
        combined_env.wait_time_constraints = {}
        schedule = []
        # Index of the source sub-problem for every combined agent / task id.
        # This replaces the old ``id // sub_env.num_*`` test, which was only
        # correct when every sub-problem had the same size as the last one.
        agent_origin = []
        task_origin = []
        for problem_idx, save_path in enumerate(save_paths):
            sub_env = SchedulingEnvironment(config_file=save_path, mode=mode)
            sub_schedule = sub_env._get_schedule_from_file()
            task_map = {}
            for task in sub_env.tasks:
                task_map[task.id] = len(combined_env.tasks)
                task.id = len(combined_env.tasks)
                combined_env.tasks.append(task)
                task_origin.append(problem_idx)
            agent_map = {}
            for agent in sub_env.agents:
                agent_map[agent.id] = len(combined_env.agents)
                agent.id = len(combined_env.agents)
                combined_env.agents.append(agent)
                agent_origin.append(problem_idx)

            # only add the obstacles from one of the environments
            if not combined_env.obstacles:
                combined_env.obstacles = sub_env.obstacles

            for key, value in sub_env.wait_time_constraints.items():
                combined_env.wait_time_constraints[task_map[key]] = [
                    {
                        'prerequisite': task_map[v['prerequisite']],
                        'wait_time': v['wait_time']
                    } for v in value
                ]
            # durations map
            for element in sub_env.duration:
                combined_env.duration.append({
                    'agent': agent_map[element['agent']],
                    'task': task_map[element['task']],
                    'duration': element['duration']
                })

            for task_id, agent_id in sub_schedule:
                schedule.append((task_map[task_id], agent_map[agent_id]))

        # Add random durations for agent-task pairings from different sub-problems.
        for agent_id, agent_problem in enumerate(agent_origin):
            for task_id, task_problem in enumerate(task_origin):
                if agent_problem != task_problem:
                    combined_env.duration.append({
                        'agent': agent_id,
                        'task': task_id,
                        'duration': np.random.randint(10, 100)  # TODO: replace this with actual
                    })

        combined_env.num_agents = len(combined_env.agents)
        combined_env.num_tasks = len(combined_env.tasks)
        combined_env.num_obstacles = len(combined_env.obstacles)
        # A merged environment returns from __init__ before loading anything,
        # so set map/deadline/speed before the path-planning and space-setup
        # steps below, which receive the environment.
        combined_env.deadline = sub_env.deadline
        combined_env.max_speed = sub_env.max_speed
        combined_env.width = sub_env.width
        combined_env.height = sub_env.height

        combined_env._get_path_planning(target_path)
        combined_env._generate_travel_times()
        combined_env._setup_observation_space()
        combined_env._setup_action_space()
        combined_env.time_step = 0
        combined_env.num_infeasible = 0

        print(schedule)
        # check the feasibility of the combined environment
        # NOTE: The MILP solver struggles at this stage because of the problem size.
        # Since the combined schedule is feasible by construction, it is used as
        # the optimal solution.
        # feasible, schedule = combined_env.solve()
        feasible = True
        if not feasible:
            # A bare ``raise`` here had no active exception to re-raise.
            raise RuntimeError("The combined problem is infeasible.")
        schedule_path = combined_env.get_schedule_location()
        with open(schedule_path, 'w') as f:
            # write the schedule into the file
            f.write(", ".join(str(s) for s in schedule))

        combined_env.save()
        return combined_env
    
def generate_environment_config(filename):
    """Write the default 5-agent / 10-task generator configuration to ``filename`` as JSON."""
    task_config = {
        'duration': {'min': 10, 'max': 100},
        'time_window_percentage': [0.2, 0.8],
        'wait_time_percentage': [0.5, 0.5],
        'wait_time_duration': [10, 100],
    }
    agent_config = {'speed': {'min': 1, 'max': 10}}
    obstacle_config = {'number': [5, 10], 'radius': [5, 25]}
    config = {
        'num_agents': [5, 5],
        'num_tasks': [10, 10],
        'map': {'width': 100, 'height': 100},
        'task_config': task_config,
        'agent_config': agent_config,
        'obstacle_config': obstacle_config,
    }
    with open(filename, 'w') as handle:
        json.dump(config, handle, indent=4)

def generate_environment(config_file, save_location = None):
    """Create a randomly generated environment from the generator config ``config_file``.

    The environment is also saved to ``save_location`` when one is given.
    """
    env = SchedulingEnvironment(generator_config = config_file)
    if save_location is None:
        return env
    env.save(save_location)
    return env

def make_env(env_location, env_id, min=1, max=2000, mode='euclidean', reward_mode='base'):
    """Create the environment for the stored problem corresponding to ``env_id``.

    ``env_id`` is wrapped into the inclusive range ``[min, max]``, so any
    integer maps to ``{env_location}/problem_{id:05d}.json``. If that file does
    not exist, ``SchedulingEnvironment`` generates a new problem instead.

    Args:
        env_location (str): Directory containing the problem files.
        env_id (int): Environment index, wrapped into ``[min, max]``.
        min (int): Smallest problem id.
        max (int): Largest problem id (inclusive).
        mode (str): Travel-distance mode for the environment.
        reward_mode (str): Reward mode for the environment.

    Returns:
        SchedulingEnvironment
    """
    # ``min``/``max`` shadow the builtins but are part of the public signature.
    problem_id = ((env_id - min) % (max - min + 1)) + min
    save_location = f"{env_location}/problem_{str(problem_id).zfill(5)}.json"
    # ``reward_mode`` was previously ignored (hard-coded to 'base').
    return SchedulingEnvironment(save_location, mode=mode, reward_mode=reward_mode)

class Environments():
    """A batch of SchedulingEnvironments, one per problem file."""

    def __init__(self, env_location, mode='euclidean') -> None:
        """Build one environment per entry of ``env_location``.

        Args:
            env_location (Iterable[str]): Problem files.
            mode (str): Travel-distance mode passed to every environment.
        """
        self.mode = mode
        # ``mode`` was previously accepted but ignored, so every environment
        # fell back to SchedulingEnvironment's own default mode.
        self.envs = [SchedulingEnvironment(env, mode=mode) for env in env_location]

    def reset(self):
        """Reset every environment and return their results in order."""
        return [env.reset() for env in self.envs]
