import numpy as np

def planning_function(processed_global_state):
    """
    Determines optimal tasks for each agent based on the current battle state.

    Decision rules, applied per agent in this order:

    1. A dead agent (falsy alive flag at ``own_info[agent_id][4]``) gets 'none'.
    2. An agent whose health (``own_info[agent_id][0]``) is below 0.3 gets
       'move', regardless of nearby enemies.
    3. An agent with at least one enemy flagged as in attack range
       (``enemy_data[0]`` truthy) gets 'attack'.
    4. Any other agent gets 'move'.

    Only the task label is returned, so no concrete target is selected here.
    An agent with no enemy entries (or no entry in ``enemy_info`` at all) is
    treated as having no enemy in range instead of raising.

    Args:
        processed_global_state: A tuple containing (available_move_actions, enemy_info, ally_info, own_info)

    Returns:
        dict: Optimal tasks for each agent ('move', 'attack', or 'none' for dead agents)
    """
    _available_move_actions, enemy_info, _ally_info, own_info = processed_global_state
    llm_tasks = {}

    for agent_id, own_data in own_info.items():
        if not own_data[4]:  # If agent is dead
            llm_tasks[agent_id] = 'none'
            continue

        # Low health always wins over attacking: retreat to a safer position.
        # NOTE(review): the 0.3 threshold presumes health is normalised to
        # [0, 1] -- confirm against the state preprocessing.
        if own_data[0] < 0.3:
            llm_tasks[agent_id] = 'move'
            continue

        # A missing or empty enemy view simply means nothing is in range.
        enemies = enemy_info.get(agent_id, {})
        enemy_in_range = any(enemy_data[0] for enemy_data in enemies.values())

        # Attack when something is in range, otherwise close the distance.
        llm_tasks[agent_id] = 'attack' if enemy_in_range else 'move'

    return llm_tasks

def compute_reward(processed_global_state, llm_tasks, tasks):
    """
    Calculate rewards based on the tasks assigned and their outcomes.

    For every living agent the reward is the sum of:

    * +0.5 if the executed task equals the task suggested in ``llm_tasks``;
    * when the executed task is 'attack': ``1 - enemy_health`` for every enemy
      flagged as in attack range (``enemy_data[0]`` truthy, health at
      ``enemy_data[5]``);
    * a penalty of ``1 - own_health`` when own health
      (``own_info[agent_id][0]``) is below 1;
    * when the executed task is 'move': +0.5 if the closest enemy distance
      (``enemy_data[1]``) lies in ``[6, 9]``; skipped when the agent has no
      enemy entries;
    * +0.5 if every ally listed in ``ally_info[agent_id]`` executed 'attack'.

    Dead agents (falsy ``own_info[agent_id][4]``) always receive 0.

    Args:
        processed_global_state: returned from function process_global_state(global_state, n, m)
        llm_tasks (dict): Dictionary containing tasks assigned to each agent.
        tasks: A dictionary of the tasks actually performed by each agent

    Returns:
        reward: Dict containing rewards for each agent

    Raises:
        KeyError: if a living agent is missing from ``llm_tasks``, ``tasks``
            or ``ally_info``, or an ally id is missing from ``tasks``.
    """
    _available_move_actions, enemy_info, ally_info, own_info = processed_global_state
    reward = {}

    for agent_id, own_data in own_info.items():
        if not own_data[4]:  # If agent is dead
            reward[agent_id] = 0
            continue

        agent_reward = 0
        agent_task = tasks[agent_id]
        # A missing enemy view is treated like an empty one.
        agent_enemies = enemy_info.get(agent_id, {})

        # Reward for following LLM task
        if llm_tasks[agent_id] == agent_task:
            agent_reward += 0.5

        # Reward for damaging enemies
        if agent_task == 'attack':
            for enemy_data in agent_enemies.values():
                if enemy_data[0]:  # If enemy is in attack range
                    agent_reward += 1 - enemy_data[5]  # Reward based on enemy's lost health

        # Penalty for taking damage
        agent_health = own_data[0]
        if agent_health < 1:
            agent_reward -= (1 - agent_health)

        # Reward for good positioning
        if agent_task == 'move':
            # default=None avoids a ValueError when no enemies are listed.
            closest_enemy_dist = min(
                (enemy_data[1] for enemy_data in agent_enemies.values()),
                default=None,
            )
            # Within sight range but outside attack range
            if closest_enemy_dist is not None and 6 <= closest_enemy_dist <= 9:
                agent_reward += 0.5

        # Reward for team coordination. all() over an empty ally set is True,
        # so an agent with no listed allies also receives this bonus.
        if all(tasks[ally_id] == 'attack' for ally_id in ally_info[agent_id]):
            agent_reward += 0.5

        reward[agent_id] = agent_reward

    return reward