import numpy as np

def planning_function(processed_global_state, *, attack_range=6, low_health=0.25,
                      kiting_agent='agent_0'):
    """
    Determine a high-level task for each agent from the current battle state.

    Every agent id present in ``own_info`` receives exactly one entry in the
    returned dict, even when the enemy distance is not comparable (e.g. NaN).

    Args:
        processed_global_state: A tuple ``(move_feats, enemy_info, ally_info, own_info)``.
            ``own_info[agent_id][5]`` is read as an alive flag (falsy means dead);
            ``enemy_info[agent_id]['enemy_0'][1][0]`` is read as the distance to
            the enemy and ``enemy_info[agent_id]['enemy_0'][5][0]`` as its health
            fraction. Only ``enemy_0`` is considered.
        attack_range: Distance at or below which the enemy counts as attackable
            (default 6, the previously hard-coded value).
        low_health: Enemy health at or below which every living agent attacks
            (default 0.25, i.e. 25%).
        kiting_agent: Id of the single agent that attacks while the enemy is
            still healthy; every other living agent moves instead.

    Returns:
        dict: Task per agent id -- ``'move'``, ``'attack'``, or ``'none'`` for
        dead agents.
    """
    _, enemy_info, _, own_info = processed_global_state
    llm_tasks = {}

    for agent_id in own_info:
        if not own_info[agent_id][5]:  # Dead agents get no task.
            llm_tasks[agent_id] = 'none'
            continue

        enemy_data = enemy_info[agent_id]['enemy_0']
        dist_to_enemy = enemy_data[1][0]
        enemy_health = enemy_data[5][0]

        # Test "in attack range" positively: anything else -- out of sight
        # (scouting), in sight but out of range (positioning), or an
        # incomparable distance such as NaN -- falls through to 'move', so no
        # living agent is ever left without a task.
        if dist_to_enemy <= attack_range:
            if enemy_health > low_health:
                # Kiting: one designated agent attacks while the others move.
                llm_tasks[agent_id] = 'attack' if agent_id == kiting_agent else 'move'
            else:
                # Enemy health is low: every living agent attacks.
                llm_tasks[agent_id] = 'attack'
        else:
            llm_tasks[agent_id] = 'move'

    return llm_tasks

def compute_reward(processed_state, llm_tasks, tasks):
    """
    Score each agent's step with a hand-shaped reward.

    Dead agents (``own_info[agent_id][5]`` falsy) receive a flat -10. Every
    living agent receives the sum of:

        +1  survival bonus,
        +1  if the performed task equals the planned task,
        +2  if the distance to ``enemy_0`` lies in [5.5, 6],
        +5  if it attacked while the enemy's health is below 1,
        -2  if the distance to ``enemy_0`` is below 5.

    Args:
        processed_state: returned from function process_global_state(global_state, n, m);
            unpacked as ``(move_feats, enemy_info, ally_info, own_info)``.
        llm_tasks (dict): Task planned for each agent.
        tasks (dict): Task each agent actually performed.

    Returns:
        dict: Reward for every agent id in ``own_info``, in ``own_info`` order.
    """
    _, enemy_info, _, own_info = processed_state
    rewards = {}

    for aid in own_info:
        if not own_info[aid][5]:
            rewards[aid] = -10  # Penalty for being dead.
            continue

        target = enemy_info[aid]['enemy_0']
        distance = target[1][0]
        hp = target[5][0]

        followed_plan = llm_tasks[aid] == tasks[aid]
        in_sweet_spot = 5.5 <= distance <= 6
        attacked_wounded_enemy = tasks[aid] == 'attack' and hp < 1
        crowding_enemy = distance < 5

        total = 1  # Survival bonus for every living agent.
        if followed_plan:
            total += 1
        if in_sweet_spot:
            total += 2
        if attacked_wounded_enemy:
            total += 5
        if crowding_enemy:
            total -= 2
        rewards[aid] = total

    return rewards