
import numpy as np
from agent import Agent


class FairSchedulingAgent(Agent):
    """Fair-share scheduling heuristic.

    Each call to :meth:`get_action` estimates how many executors every job
    currently holds (bound, in transit, and committed), derives an equal
    per-job share from that total, and hands the free executors to a
    schedulable node, preferring the job that released the executors and then
    the job furthest below its share.

    Scheduling complexity (original author's estimate):
    O(num_jobs * num_nodes).
    """

    def __init__(self, min_share=1):
        """Create the agent.

        Args:
            min_share: Lower bound applied to the per-job fair share computed
                in :meth:`get_action`.
        """
        Agent.__init__(self)
        self.min_share = min_share  # Minimum executors each job should get

    @staticmethod
    def _node_priority(node):
        """Return the ``(num_tasks, duration)`` ranking key for *node*.

        The duration comes from ``node.get_node_duration()`` when the node
        provides that method; otherwise it is 0, so ranking falls back to the
        task count alone.
        """
        duration_fn = getattr(node, 'get_node_duration', None)
        duration = duration_fn() if callable(duration_fn) else 0
        return node.num_tasks, duration

    @staticmethod
    def _available_tasks(node, exec_commit, moving_executors):
        """Return how many tasks of *node* still lack an executor.

        Subtracts tasks already started (``next_task_idx``), executors
        committed to the node, and executors moving towards it. Missing
        ``node_commit`` / ``count`` members are treated as contributing 0.
        The result is clamped to be non-negative.
        """
        committed = getattr(exec_commit, 'node_commit', {}).get(node, 0)
        moving = getattr(moving_executors, 'count', lambda n: 0)(node)
        return max(
            0, node.num_tasks - node.next_task_idx - committed - moving)

    def _pick_node(self, job_dag, frontier_nodes):
        """Return the highest-priority schedulable node of *job_dag*.

        Nodes on the job's own frontier that are also in *frontier_nodes* are
        preferred; if there are none, any node in *frontier_nodes* belonging
        to the job is considered. Returns ``None`` when the job has no
        schedulable node.
        """
        candidates = [node for node in job_dag.frontier_nodes
                      if node in frontier_nodes]
        if not candidates:
            candidates = [node for node in frontier_nodes
                          if node.job_dag == job_dag]
        if not candidates:
            return None
        return max(candidates, key=self._node_priority)

    def get_action(self, obs):
        """Choose a node and an executor count for the free executors.

        Args:
            obs: 8-tuple unpacked as ``(job_dags, source_job,
                num_source_exec, frontier_nodes, executor_limits,
                exec_commit, moving_executors, action_map)``.

        Returns:
            ``(node, num_executors)`` for the selected node, or
            ``(None, num_source_exec)`` when nothing can be scheduled.
        """
        # parse observation (action_map is not used by this heuristic)
        job_dags, source_job, num_source_exec, \
            frontier_nodes, executor_limits, \
            exec_commit, moving_executors, _action_map = obs

        # Nothing to schedule without schedulable nodes or jobs.
        num_jobs = len(job_dags)
        if not frontier_nodes or num_jobs == 0:
            return None, num_source_exec

        # Current allocation per job: executors already bound to the job ...
        current_allocation = {
            job_dag: len(job_dag.executors) for job_dag in job_dags}

        # ... plus executors in transit ...
        for node in moving_executors.moving_executors.values():
            if node.job_dag in current_allocation:
                current_allocation[node.job_dag] += 1

        # ... plus committed executors (commitments whose source is None
        # are skipped).
        for source in exec_commit.commit:
            if source is None:
                continue
            commits = exec_commit.commit[source]
            for node in commits:
                if node is not None and node.job_dag in current_allocation:
                    current_allocation[node.job_dag] += commits[node]

        # Fair share: equal split of the counted executors, never below
        # min_share (defaults to 1 if the attribute was never set).
        total_executors = sum(current_allocation.values())
        fair_share = max(getattr(self, 'min_share', 1),
                         int(total_executors / num_jobs))

        # First try to keep the executors on the job they came from
        # (locality optimization).
        if source_job is not None:
            selected_node = self._pick_node(source_job, frontier_nodes)
            if selected_node is not None:
                return selected_node, num_source_exec

        # How far each job is below its fair share.
        job_deficits = {job: fair_share - current_allocation[job]
                        for job in job_dags}

        # Consider executor limits if specified.
        # NOTE(review): limits are looked up by job name (``job.name``,
        # falling back to ``str(job)``) -- confirm this matches how the
        # environment keys ``executor_limits``.
        for job in job_deficits:
            job_name = getattr(job, 'name', str(job))
            if job_name in executor_limits \
                    and executor_limits[job_name] is not None:
                if current_allocation[job] >= executor_limits[job_name]:
                    job_deficits[job] = 0  # Already at limit

        # Most under-resourced first; ties broken by larger frontier workload.
        sorted_jobs = sorted(
            job_deficits,
            key=lambda job: (
                job_deficits[job],
                sum(getattr(n, 'num_tasks', 0) for n in job.frontier_nodes)),
            reverse=True)

        for job_dag in sorted_jobs:
            if job_deficits[job_dag] <= 0:
                # This job has its fair share already
                continue

            next_node = self._pick_node(job_dag, frontier_nodes)
            if next_node is None:
                continue

            use_exec = min(
                self._available_tasks(
                    next_node, exec_commit, moving_executors),
                job_deficits[job_dag],  # Only allocate up to fair share
                num_source_exec)
            if use_exec > 0:
                return next_node, use_exec

        # Every job has its fair share (or could not use more): hand the
        # executors to the highest-priority node that still has free tasks.
        for node in sorted(frontier_nodes, key=self._node_priority,
                           reverse=True):
            available_tasks = self._available_tasks(
                node, exec_commit, moving_executors)
            if available_tasks > 0:
                return node, min(available_tasks, num_source_exec)

        # There are more executors than tasks in the system
        return None, num_source_exec