
import numpy as np
from agent import Agent


class SJFAgent(Agent):
    """Shortest-Job-First scheduling agent.

    Jobs are ranked by their total count of outstanding tasks (tasks not yet
    covered by launched work, committed executors, or moving executors),
    used as a proxy for remaining duration. Free executors are sent to a
    schedulable node of the smallest job that has one. Executors are first
    offered back to ``source_job`` when that job can still use them
    (presumably the job the executors are coming from -- confirm with the
    environment).

    Scheduling complexity (original note): O(num_jobs * num_nodes).
    """

    def __init__(self):
        Agent.__init__(self)

    @staticmethod
    def _unclaimed_tasks(node, exec_commit, moving_executors):
        """Return the number of ``node``'s tasks not yet covered.

        Computed as ``node.num_tasks - node.next_task_idx`` minus the
        executors committed to the node (``exec_commit.node_commit[node]``)
        minus the executors moving to it (``moving_executors.count(node)``).
        ``next_task_idx`` presumably counts already-launched tasks -- confirm.
        The result is NOT clamped and may be zero or negative.
        """
        return (node.num_tasks - node.next_task_idx
                - exec_commit.node_commit[node]
                - moving_executors.count(node))

    @staticmethod
    def _pick_node(job_dag, frontier_nodes):
        """Return a schedulable node of ``job_dag``, or None if there is none.

        Prefers the job's own frontier nodes that are also in the global
        ``frontier_nodes``; otherwise falls back to the first global frontier
        node whose ``job_dag`` is this job.
        """
        for node in job_dag.frontier_nodes:
            if node in frontier_nodes:
                return node
        for node in frontier_nodes:
            if node.job_dag == job_dag:
                return node
        return None

    def get_action(self, obs):
        """Choose the next node to schedule and how many executors to send.

        Args:
            obs: tuple ``(job_dags, source_job, num_source_exec,
                frontier_nodes, executor_limits, exec_commit,
                moving_executors, action_map)``. ``executor_limits`` and
                ``action_map`` are unpacked but not used by this policy.

        Returns:
            ``(node, num_exec)``. For ``source_job`` the full
            ``num_source_exec`` is returned. For a job picked by SJF,
            ``num_exec`` is capped at the node's outstanding task count. If
            no job has a schedulable node, returns ``(None, num_source_exec)``.
        """
        job_dags, source_job, num_source_exec, \
            frontier_nodes, executor_limits, \
            exec_commit, moving_executors, action_map = obs

        # First try to keep the executors on the same job.
        if source_job is not None:
            source_node = self._pick_node(source_job, frontier_nodes)
            if source_node is not None:
                return source_node, num_source_exec

        def remaining_work(job_dag):
            # Total outstanding tasks over the whole job; each node's count
            # is clamped at 0 so over-committed nodes cannot shrink the total.
            return sum(
                max(0, self._unclaimed_tasks(node, exec_commit,
                                             moving_executors))
                for node in job_dag.nodes)

        # Shortest job first; sorted() is stable, so ties keep the order of
        # job_dags. Later jobs are only tried if earlier ones have no
        # schedulable node.
        for job_dag in sorted(job_dags, key=remaining_work):
            next_node = self._pick_node(job_dag, frontier_nodes)
            if next_node is not None:
                # Do not send more executors than the selected node needs.
                use_exec = min(
                    self._unclaimed_tasks(next_node, exec_commit,
                                          moving_executors),
                    num_source_exec)
                return next_node, use_exec

        # there are more executors than tasks in the system
        return None, num_source_exec