
import numpy as np
from agent import Agent
from collections import defaultdict


class MultilevelFeedbackAgent(Agent):
    """Multilevel Feedback Queue (MLFQ) scheduling agent.

    New jobs enter queue 0, the highest priority. Queue ``i`` has a time
    quantum of ``time_quantum_base * 2**i`` executors. Every executor
    handed to a job is charged against that quantum:

    * When the charge reaches the quantum, the job is demoted one level.
      At the lowest level it stays put. Its charge is then reset.
    * Either way, the job moves to the back of its (new) queue.

    Jobs at the same level are therefore served round-robin, and no single
    job can monopolize a level.

    Each call probes queued jobs against the frontier in priority order. The
    worst case therefore grows with the number of jobs times the number of
    frontier/job nodes. The exact cost depends on the container types of
    ``frontier_nodes`` / ``job_dag.nodes``, which are not visible here.
    """

    def __init__(self, num_queues=3, time_quantum_base=1):
        """Create the agent.

        Args:
            num_queues: number of priority levels; must be >= 1.
            time_quantum_base: quantum (in executors) of the top queue;
                each lower level doubles it. Must be > 0.

        Raises:
            ValueError: if either argument is out of range.
        """
        Agent.__init__(self)
        if num_queues < 1:
            raise ValueError("num_queues must be at least 1")
        if time_quantum_base <= 0:
            raise ValueError("time_quantum_base must be positive")
        self.num_queues = num_queues
        self.time_quantum_base = time_quantum_base

        # queues[i] is the ordered list of jobs at priority level i (0 = highest).
        self.queues = [[] for _ in range(num_queues)]
        # job -> index of the queue it currently sits in.
        self.job_queue_mapping = {}
        # job -> executors charged against its current quantum.
        self.job_usage_stats = {}

    def _sync_jobs(self, job_dags):
        """Admit unseen jobs into queue 0 and forget jobs no longer listed."""
        for job_dag in job_dags:
            if job_dag not in self.job_queue_mapping:
                self.queues[0].append(job_dag)
                self.job_queue_mapping[job_dag] = 0
                self.job_usage_stats[job_dag] = 0

        current_jobs = set(job_dags)
        for queue in self.queues:
            queue[:] = [j for j in queue if j in current_jobs]
        self.job_queue_mapping = {
            j: q for j, q in self.job_queue_mapping.items() if j in current_jobs}
        self.job_usage_stats = {
            j: u for j, u in self.job_usage_stats.items() if j in current_jobs}

    @staticmethod
    def _find_schedulable_node(job_dag, frontier_nodes):
        """Return a node of ``job_dag`` that is in ``frontier_nodes``, else None.

        Nodes from the job's own ``frontier_nodes`` are preferred; otherwise
        the first frontier node contained in ``job_dag.nodes`` is used.
        """
        for node in job_dag.frontier_nodes:
            if node in frontier_nodes:
                return node
        for node in frontier_nodes:
            if node in job_dag.nodes:
                return node
        return None

    def _requeue(self, job_dag, queue_idx, used, time_quantum):
        """Charge ``used`` executors to ``job_dag`` and move it to its next slot."""
        self.job_usage_stats[job_dag] += used
        target_idx = queue_idx
        if self.job_usage_stats[job_dag] >= time_quantum:
            # Quantum exhausted: demote one level (the lowest level keeps it).
            target_idx = min(queue_idx + 1, self.num_queues - 1)
            self.job_usage_stats[job_dag] = 0
        # Always send the job to the back of its target queue; otherwise the
        # head of the lowest queue would be picked again on every call.
        self.queues[queue_idx].remove(job_dag)
        self.queues[target_idx].append(job_dag)
        self.job_queue_mapping[job_dag] = target_idx

    def get_action(self, obs):
        """Pick the next node to run and how many executors to give it.

        Args:
            obs: 8-tuple ``(job_dags, source_job, num_source_exec,
                frontier_nodes, executor_limits, exec_commit,
                moving_executors, action_map)``. ``executor_limits`` and
                ``action_map`` are not used by this policy.

        Returns:
            ``(node, num_executors)``. Returns ``(None, num_source_exec)``
            when no queued job has a schedulable node.
        """
        job_dags, source_job, num_source_exec, \
        frontier_nodes, executor_limits, \
        exec_commit, moving_executors, action_map = obs

        self._sync_jobs(job_dags)

        # Prefer keeping the executors on the job they are coming from.
        if source_job is not None:
            for node in source_job.frontier_nodes:
                if node in frontier_nodes:
                    return node, num_source_exec
            for node in frontier_nodes:
                if node.job_dag == source_job:
                    return node, num_source_exec

        # Walk levels from highest priority down, and every job within a
        # level in order, so a blocked head job cannot hide the jobs behind it.
        for queue_idx, queue in enumerate(self.queues):
            time_quantum = self.time_quantum_base * (2 ** queue_idx)
            for job_dag in list(queue):  # snapshot: _requeue mutates queues
                next_node = self._find_schedulable_node(job_dag, frontier_nodes)
                if next_node is None:
                    continue
                # Tasks of next_node not yet launched, committed, or covered
                # by executors already moving to it.
                # NOTE(review): assumes frontier nodes always have remaining
                # tasks, so this stays positive — confirm against the env.
                remaining = (next_node.num_tasks
                             - next_node.next_task_idx
                             - exec_commit.node_commit[next_node]
                             - moving_executors.count(next_node))
                use_exec = min(remaining, time_quantum, num_source_exec)
                self._requeue(job_dag, queue_idx, use_exec, time_quantum)
                return next_node, use_exec

        # Nothing schedulable: more executors than runnable tasks.
        return None, num_source_exec