
import numpy as np
from agent import Agent


class SRTFAgent(Agent):
    """Shortest Remaining Time First (SRTF) scheduling agent.

    Jobs are ranked by a proxy for their remaining time: the sum over the
    job's nodes of the tasks still unassigned (clamped at zero), divided by
    the job's current executor count plus one. The first job in that order
    that owns a node present in the observed frontier receives the source
    executors.

    Scheduling complexity noted by the original author:
    O(num_jobs * num_nodes).
    """

    def __init__(self):
        Agent.__init__(self)

    @staticmethod
    def _unassigned_tasks(node, exec_commit, moving_executors):
        # node.num_tasks minus node.next_task_idx, minus
        # exec_commit.node_commit[node], minus moving_executors.count(node).
        # The result is not clamped here; callers clamp when they need to.
        return (node.num_tasks - node.next_task_idx
                - exec_commit.node_commit[node]
                - moving_executors.count(node))

    def get_action(self, obs):
        """Choose the node to schedule next and the executor count for it.

        Args:
            obs: 8-tuple ``(job_dags, source_job, num_source_exec,
                frontier_nodes, executor_limits, exec_commit,
                moving_executors, action_map)``. Only ``job_dags``,
                ``num_source_exec``, ``frontier_nodes``, ``exec_commit``
                and ``moving_executors`` are read by this agent.

        Returns:
            ``(node, use_exec)`` where ``node`` belongs to the job with the
            smallest remaining-time estimate that has a node in
            ``frontier_nodes`` and ``use_exec`` is the smaller of that
            node's unassigned task count and ``num_source_exec``; or
            ``(None, num_source_exec)`` when no job has such a node.
        """
        # parse observation (underscore-prefixed fields are unused here)
        (job_dags, _source_job, num_source_exec,
         frontier_nodes, _executor_limits,
         exec_commit, moving_executors, _action_map) = obs

        def remaining_time(job_dag):
            # Remaining tasks stand in for remaining time.
            remaining_tasks = sum(
                max(0, self._unassigned_tasks(
                    node, exec_commit, moving_executors))
                for node in job_dag.nodes)
            # +1 keeps the divisor non-zero for jobs with no executors
            return remaining_tasks / (len(job_dag.executors) + 1)

        # Shortest estimated remaining time first; sorted() is stable, so
        # ties keep the order in which job_dags yields them.
        for job_dag in sorted(job_dags, key=remaining_time):
            # immediately schedulable node first
            next_node = next(
                (node for node in job_dag.frontier_nodes
                 if node in frontier_nodes),
                None)
            # then any schedulable node in the job
            if next_node is None:
                next_node = next(
                    (node for node in frontier_nodes
                     if node in job_dag.nodes),
                    None)
            if next_node is None:
                continue

            # node is selected, compute limit from the chosen node itself
            use_exec = min(
                self._unassigned_tasks(
                    next_node, exec_commit, moving_executors),
                num_source_exec)
            return next_node, use_exec

        # there are more executors than tasks in the system
        return None, num_source_exec