
import numpy as np
from agent import Agent


class TetrisAgent(Agent):
    """Tetris-style multi-resource packing heuristic.

    Freed executors stay on their source job when that job has a schedulable
    node. Otherwise every schedulable frontier node is scored from its
    per-resource demands (or its task count when demands are unavailable),
    and executors go to the best-scoring node that still has tasks left.

    Per call, cost is roughly O(F log F + F * M) for F frontier nodes and
    M moving executors (one sort plus a ``list.count`` per ranked node).
    """

    def __init__(self, resources=None):
        """Create the agent.

        Args:
            resources: iterable of resource names to read from each node's
                ``resource_demands`` mapping. ``None`` (the default) means
                ``['cpu', 'memory', 'disk', 'network']``. The value is copied
                into a new list, so instances never share a default object
                and a one-shot iterable is safe to pass.
        """
        Agent.__init__(self)
        if resources is None:
            resources = ['cpu', 'memory', 'disk', 'network']
        self.resources = list(resources)

    def _score_node(self, node):
        """Return the packing score for ``node`` (higher is better)."""
        if not hasattr(node, 'resource_demands'):
            # No per-resource information: fall back to the task count.
            return node.num_tasks

        # NOTE(review): assumes resource_demands is a dict-like
        # {resource_name: amount}, e.g. {'cpu': 2, 'memory': 4, ...}, and that
        # such nodes also expose num_executors_requested -- confirm.
        demand_vector = [node.resource_demands.get(res, 0)
                         for res in self.resources]

        # Total demand over the tracked resources.
        resource_utilization = sum(demand_vector)

        # Negative spread: even demand across resources scores higher.
        # np.std([]) is NaN (plus a RuntimeWarning), and a NaN score would
        # make the ranking below ill-defined, so an empty resource list
        # contributes 0.
        resource_balance = -np.std(demand_vector) if demand_vector else 0.0

        # Tasks per requested executor; max() avoids division by zero.
        task_density = node.num_tasks / max(1, node.num_executors_requested)

        return resource_utilization + resource_balance + task_density

    def get_action(self, obs):
        """Pick a frontier node and the number of executors to send to it.

        Args:
            obs: observation tuple ``(job_dags, source_job, num_source_exec,
                frontier_nodes, executor_limits, exec_commit,
                moving_executors, action_map)``. Only ``source_job``,
                ``num_source_exec``, ``frontier_nodes``, ``exec_commit`` and
                ``moving_executors`` are read here.

        Returns:
            ``(node, num_exec)``. If a schedulable node belongs to
            ``source_job``, that node is returned with ``num_source_exec``.
            Otherwise the returned count is capped by the node's outstanding
            task count. ``(None, num_source_exec)`` is returned when no
            frontier node has tasks left to assign.
        """
        (_job_dags, source_job, num_source_exec,
         frontier_nodes, _executor_limits,
         exec_commit, moving_executors, _action_map) = obs

        # Prefer keeping executors on the job they came from.
        if source_job is not None:
            # Nodes the source job itself reports as frontier nodes.
            for node in source_job.frontier_nodes:
                if node in frontier_nodes:
                    return node, num_source_exec
            # Any schedulable node belonging to the source job.
            for node in frontier_nodes:
                if node.job_dag == source_job:
                    return node, num_source_exec

        scores = {node: self._score_node(node) for node in frontier_nodes}

        # Best score first. sorted() is stable even with reverse=True, so
        # ties keep frontier_nodes iteration order.
        ranked = sorted(scores, key=scores.get, reverse=True)

        for node in ranked:
            # Tasks not yet launched, minus executors already accounted for
            # this node (exec_commit.node_commit, presumably committed
            # executors) and executors currently moving to it.
            remaining = (node.num_tasks - node.next_task_idx
                         - exec_commit.node_commit[node]
                         - moving_executors.count(node))
            use_exec = min(remaining, num_source_exec)
            if use_exec > 0:
                return node, use_exec

        # More executors than outstanding tasks in the system.
        return None, num_source_exec