"""
author: Anonymous

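Earliest Deadline First (EDF) is a scheduling algorithm that assigns tasks to agents based on the earliest deadline. The improvement in this variant is that an agent is not assigned a task while a prerequisite of that task is still unassigned. Concretely, the agent with the smallest current makespan is chosen, and it receives the feasible task with the smallest travel time plus duration. This is a simple heuristic intended as a baseline for judging how effective the learning models are.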
"""
import sys
sys.path.append("..")

import numpy as np
import torch

from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from scheduling.task import Task

class ImprovedEarliestDeadlineFirstAgent():
    """Greedy, prerequisite-aware scheduling heuristic used as a baseline.

    Every decision picks the agent with the smallest current makespan and then,
    among the unassigned tasks none of whose prerequisites are still unassigned,
    the task with the smallest travel time + duration for that agent. The policy
    is deterministic, so the probabilities it reports are one-hot.

    The observation is read as a mapping with the keys 'agent',
    'agent_to_task_assignment', 'task_assignment', 'task_to_task_select' and
    'wait_time'; the edge entries expose `source_nodes` / `target_nodes`
    (and `edge_features` for 'agent_to_task_assignment').
    """

    def __init__(self):
        self.name = "Improved EDF"

    def set_environment(self, environment, rerun=False):
        """No-op: this heuristic keeps no per-environment state."""
        pass

    @staticmethod
    def _select_agent(observation):
        """Return (agent_times, agent_id) where agent_id has the smallest makespan."""
        agent_times = observation['agent'][:, Agent.agent_current_makespan_index()]
        return agent_times, np.argmin(agent_times)

    @staticmethod
    def _feasible_tasks(observation):
        """Return (unassigned_tasks, feasible_tasks).

        `unassigned_tasks` are the source nodes of 'task_to_task_select'.
        A task is feasible when it is unassigned and NONE of its prerequisites
        (sources of 'wait_time' edges pointing at it) is still unassigned.
        `feasible_tasks` keeps the order of `unassigned_tasks`.
        """
        unassigned_tasks = observation['task_to_task_select'].source_nodes
        prerequisite_tasks = observation['wait_time'].source_nodes
        dependent_tasks = observation['wait_time'].target_nodes

        # int() normalises numpy/torch scalars so set membership compares by value.
        unassigned = {int(task) for task in unassigned_tasks}
        # A task may have several incoming 'wait_time' edges; it is blocked if ANY
        # of its prerequisites is still unassigned (not just the first one).
        blocked = {
            int(dependent)
            for prerequisite, dependent in zip(prerequisite_tasks, dependent_tasks)
            if int(prerequisite) in unassigned
        }
        feasible_tasks = [task for task in unassigned_tasks if int(task) not in blocked]
        return unassigned_tasks, feasible_tasks

    @classmethod
    def _select_task(cls, observation, agent_id):
        """Return (task_id, unassigned_tasks) for the cheapest feasible task of `agent_id`.

        Raises:
            ValueError: if no task is feasible (nothing left, or every unassigned
                task still waits on an unassigned prerequisite).
        """
        agent_to_task_assignments = observation['agent_to_task_assignment']
        # Edges leaving this agent and the nodes they point to.
        task_indices = np.where(agent_to_task_assignments.source_nodes == agent_id)[0]
        tasks = agent_to_task_assignments.target_nodes[task_indices]

        unassigned_tasks, feasible_tasks = cls._feasible_tasks(observation)
        if not feasible_tasks:
            raise ValueError(
                "No feasible task: no unassigned task is left, or every unassigned "
                "task still waits on an unassigned prerequisite."
            )

        # NOTE(review): assumes this agent's target nodes are ordered by task id
        # (tasks[task_id] is the node for task_id) and that edge i targets node i,
        # since edge_features is indexed by node id below — TODO confirm against
        # the environment's graph layout.
        task_assignment_indices = tasks[feasible_tasks]
        travel_time = agent_to_task_assignments.edge_features[task_assignment_indices]
        duration = observation['task_assignment'][task_assignment_indices]
        # Expected completion = travel + duration (not an actual deadline).
        # NOTE(review): argmin flattens; assumes one feature value per node/edge.
        expected_completion_time = travel_time + duration
        task_id = feasible_tasks[np.argmin(expected_completion_time)]
        return task_id, unassigned_tasks

    def get_action(self, observation, greedy=True):
        """Choose the next (task, agent) pair.

        Args:
            observation (dict): Observation mapping (see class docstring).
            greedy (bool): Unused; the heuristic is always deterministic.
        Returns:
            tuple: ((task_id, agent_id), None, None, None)
        """
        _, agent_id = self._select_agent(observation)
        task_id, _ = self._select_task(observation, agent_id)
        return (task_id, agent_id), None, None, None

    def get_agent_probs(self, observation, greedy=True, adaptive_temperature=False):
        """Get one-hot agent probabilities for the least-loaded agent.

        Args:
            observation (dict): Observation mapping (see class docstring).
            greedy (bool): Unused.
            adaptive_temperature (bool): Unused.
        Returns:
            tuple: (agent_id, probabilities as a (1, n_agents) float64 tensor,
                log probabilities as a numpy array; 1e-8 avoids log(0)).
        """
        agent_times, agent_id = self._select_agent(observation)
        agent_probs = np.zeros(len(agent_times))
        agent_probs[agent_id] = 1.0
        log_probs = np.log(agent_probs + 1e-8)
        return agent_id, torch.tensor(agent_probs[np.newaxis, :]), log_probs

    def get_task_probs(self, observation, agent_id, greedy=True, adaptive_temperature=False):
        """Get one-hot task probabilities over the unassigned tasks for `agent_id`.

        Args:
            observation (dict): Observation mapping (see class docstring).
            agent_id (int): Agent ID.
            greedy (bool): Unused.
            adaptive_temperature (bool): Unused.
        Returns:
            tuple: (task_id, probabilities as a (1, n_unassigned) float64 tensor,
                log probabilities as a numpy array; 1e-8 avoids log(0)).
        Raises:
            ValueError: if no task is feasible.
        """
        task_id, unassigned_tasks = self._select_task(observation, agent_id)
        task_probs = np.zeros(len(unassigned_tasks))
        # Position of the chosen task among the 'task_to_task_select' candidates.
        t2ta_id = np.where(unassigned_tasks == task_id)[0][0]
        task_probs[t2ta_id] = 1.0
        log_probs = np.log(task_probs + 1e-8)
        return task_id, torch.tensor(task_probs[np.newaxis, :]), log_probs
