"""
author: Anonymous
"""

import os
import sys
import inspect

# Resolve the directory that contains this file, take its parent directory, and
# put that parent at the front of sys.path.
# NOTE(review): presumably this lets the `scheduling` imports further down
# resolve when the file is run directly rather than as an installed package;
# confirm before removing.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import numpy as np
import torch
import torch.nn as nn

from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from scheduling.task import Task

class EarliestDeadlineFirstAgent():
    """Deterministic heuristic baseline with no learnable parameters.

    Agent choice: the agent whose current-makespan feature is smallest.
    Task choice: among the unassigned tasks, the one with the smallest
    ``travel_time + duration`` for the chosen agent.

    The observation is a mapping that is read through these keys:
        * ``'agent'``: 2-D array of agent features; the makespan column is
          given by ``Agent.agent_current_makespan_index()``.
        * ``'agent_to_task_assignment'``: edge object exposing
          ``source_nodes``, ``target_nodes`` and ``edge_features``.
        * ``'task_to_task_select'``: edge object whose ``source_nodes`` are
          used as the list of unassigned tasks.
        * ``'task_assignment'``: array indexed by task-assignment node.
    """

    def __init__(self):
        self.name = "EDF"

    def set_environment(self, environment, rerun=False):
        """Accept an environment for interface compatibility; does nothing."""
        pass

    @staticmethod
    def _select_agent(observation):
        """Return ``(agent_id, agent_times)`` for the agent with minimal makespan."""
        agent_times = observation['agent'][:, Agent.agent_current_makespan_index()]
        return np.argmin(agent_times), agent_times

    @staticmethod
    def _select_task(observation, agent_id):
        """Pick the unassigned task with the smallest travel time + duration.

        Returns:
            tuple: ``(task_id, position, unassigned_tasks)`` where ``position``
            is the index of ``task_id`` inside ``unassigned_tasks``.
        Raises:
            ValueError: If the observation lists no unassigned tasks.
        """
        # EdgeInstance: source_nodes, target_nodes, edge_features
        assignments = observation['agent_to_task_assignment']
        # Edges whose source node is the selected agent ...
        agent_edges = np.where(assignments.source_nodes == agent_id)[0]
        # ... and the task-assignment nodes those edges point to.
        agent_assignment_nodes = assignments.target_nodes[agent_edges]

        unassigned_tasks = observation['task_to_task_select'].source_nodes
        if len(unassigned_tasks) == 0:
            # Fail with a clear message instead of numpy's opaque argmin error.
            raise ValueError("EDF agent cannot select a task: no unassigned tasks in the observation")

        # NOTE(review): this indexes the agent's assignment nodes by task id and
        # then indexes edge_features by assignment-node id, so it assumes both
        # orderings line up - confirm against the environment.
        candidate_indices = agent_assignment_nodes[unassigned_tasks]
        travel_time = assignments.edge_features[candidate_indices]
        duration = observation['task_assignment'][candidate_indices]
        # NOTE(review): argmin works on the flattened sum; this assumes exactly
        # one value per candidate - confirm the feature shapes.
        position = np.argmin(travel_time + duration)
        return unassigned_tasks[position], position, unassigned_tasks

    @staticmethod
    def _one_hot(size, index):
        """Return ``(probs, log_probs)`` for a one-hot distribution.

        ``probs`` is a float64 torch tensor of shape ``(1, size)`` and
        ``log_probs`` is a numpy array of shape ``(size,)`` holding
        ``log(p + 1e-8)``.
        """
        probs = np.zeros(size)
        probs[index] = 1.0
        log_probs = np.log(probs + 1e-8)
        # Convert one (1, size) ndarray rather than a Python list of ndarrays:
        # same result, but avoids PyTorch's slow path and its UserWarning.
        return torch.tensor(probs[np.newaxis, :]), log_probs

    def get_action(self, observation, greedy=False):
        """Select an agent and a task for it.

        Args:
            observation (dict): Observation mapping (see class docstring).
            greedy (bool): Unused; the heuristic is always deterministic.
        Returns:
            tuple: ``((task_id, agent_id), None, None, None)``.
        Raises:
            ValueError: If the observation lists no unassigned tasks.
        """
        agent_id, _ = self._select_agent(observation)
        task_id, _, _ = self._select_task(observation, agent_id)
        return (task_id, agent_id), None, None, None

    def get_agent(self, observation, greedy=False):
        """Select the agent with the smallest current makespan.

        Args:
            observation (dict): Observation mapping (see class docstring).
            greedy (bool): Unused; the heuristic is always deterministic.
        Returns:
            tuple: ``(agent_id, None, None, None)``.
        """
        agent_id, _ = self._select_agent(observation)
        return agent_id, None, None, None

    def get_agent_probs(self, observation, greedy=False, adaptive_temperature=False):
        """Get one-hot agent probabilities.

        Args:
            observation (dict): Observation mapping (see class docstring).
            greedy (bool): Unused.
            adaptive_temperature (bool): Unused.
        Returns:
            tuple: ``(agent_id, agent_probs, log_probs)`` where ``agent_probs``
            is a torch tensor of shape ``(1, num_agents)`` and ``log_probs`` is
            a numpy array of shape ``(num_agents,)``.
        """
        agent_id, agent_times = self._select_agent(observation)
        agent_probs, log_probs = self._one_hot(len(agent_times), agent_id)
        return agent_id, agent_probs, log_probs

    def get_task_probs(self, observation, agent_id, greedy=True):
        """Get one-hot task probabilities for the given agent.

        Args:
            observation (dict): Observation mapping (see class docstring).
            agent_id (int): Agent the task is selected for.
            greedy (bool): Unused; the heuristic is always deterministic.
        Returns:
            tuple: ``(task_id, task_probs, log_probs)`` where ``task_probs`` is
            a torch tensor of shape ``(1, num_unassigned_tasks)`` and
            ``log_probs`` is a numpy array of shape ``(num_unassigned_tasks,)``.
        Raises:
            ValueError: If the observation lists no unassigned tasks.
        """
        task_id, position, unassigned_tasks = self._select_task(observation, agent_id)
        task_probs, log_probs = self._one_hot(len(unassigned_tasks), position)
        return task_id, task_probs, log_probs
        
class EarliestDeadlineFirstAgentScheduler:
    """Heuristic agent scheduler that always picks the least-busy agent.

    Exposes ``forward`` and ``replay`` methods but holds no parameters and
    does no learning.
    """

    def __init__(self):
        self.name = "EDF_Agent"

    def forward(self, observation, memory=False):
        """Return a one-hot distribution over agents.

        Args:
            observation (dict): Mapping whose ``'agent'`` entry is a 2-D array;
                the column ``Agent.agent_current_makespan_index()`` is read.
            memory (bool): Unused.
        Returns:
            tuple: ``(probabilities, None, None, None)`` where
            ``probabilities`` is a 1-D torch tensor with 1.0 at the agent with
            the smallest makespan value and 0.0 elsewhere.
        """
        makespan_column = Agent.agent_current_makespan_index()
        makespans = observation['agent'][:, makespan_column]
        chosen = np.argmin(makespans)
        one_hot = np.zeros(len(makespans))
        one_hot[chosen] = 1.0
        return torch.tensor(one_hot), None, None, None

    def replay(self, g, nf, ef, mode):
        """Do nothing and return ``None``; all arguments are ignored."""
        return None