from harl.common.skills.smacv2.atomic_actions.combat import attack
from harl.common.skills.smacv2.atomic_actions.move import move_north, move_south, move_east, move_west
from harl.common.skills.smacv2.atomic_actions.basic import stop
from harl.common.skills.smacv2.atomic_actions.heal import heal
from harl.common.skills.smacv2.composite_skills import default_tactic as default_action, find_path
from harl.common.skills.smacv2.skill_registry import register_skill
from harl.utils.skill_utils import parse_obs
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Set
import random

@register_skill("race_terran_medivac_navi_A_star_score_priority_health_default_allies")
def race_terran_medivac_navi_A_star_score_priority_health_default_allies(obs: str):
    """
    Specialized Terran Medivac Control Logic:
    - Priority-based healing that weighs unit type, missing health, distance and
      proximity to nearby enemies
    - Repositioning toward a point beyond the allied centroid, away from the enemy
      centroid, when enemies are present
    - Fleeing away from the enemy centroid when no allies remain

    Implementation Details:
    - Heal-target scoring combines unit-type priority, a health factor, a distance
      factor, a distance-from-enemies factor, a combat bonus, a persistence bonus for
      the previous heal target and a group-cohesion bonus
    - Movement is delegated to ``find_path`` (A* navigation helper)
    - Falls back to ``default_action(obs)`` when no heal or movement action is chosen

    Args:
        obs (str): Observation string containing game state

    Returns:
        ``0`` if action 0 is available; otherwise the result of ``heal(...)``,
        a truthy ``find_path(...)`` result, or ``default_action(obs)``.
    """
    import math

    # Heal priority per ally type; unlisted types use 5.0.
    unit_priorities = {
        'marine': 60.0,     # Priority due to low HP pool
        'marauder': 80.0,   # High priority due to high resource cost
        'medivac': 40.0,    # Lower priority for other support units
    }
    # Health fraction below which the exponential urgency boost applies; unlisted types use 0.6.
    healing_urgency = {
        'marine': 0.7,      # Marines need earlier healing due to fragility
        'marauder': 0.6,    # Marauders can sustain more damage
        'medivac': 0.5,     # Support units lowest priority
    }
    heal_cutoff = 0.9  # allies at or above this health fraction are never healed
    map_size = 32      # NOTE(review): assumes a 32x32 map with own_position normalized to [0, 1] — confirm

    obs_data = parse_obs(obs)
    valid_actions = obs_data.available_actions

    # Action 0 is returned as-is whenever available (presumably the no-op of a dead unit).
    if 0 in valid_actions:
        return 0

    def scoring_context():
        """Return (nearby-enemy centroid or None, count of allies within 0.5); shared by all targets."""
        nearby_enemies = [e for e in obs_data.enemies if e.distance < 1.0]
        if nearby_enemies:
            centroid = (
                sum(e.position[0] for e in nearby_enemies) / len(nearby_enemies),
                sum(e.position[1] for e in nearby_enemies) / len(nearby_enemies),
            )
        else:
            centroid = None
        close_allies = sum(1 for a in obs_data.allies if a.distance < 0.5)
        return centroid, close_allies

    def score_target(unit, enemy_centroid, close_allies):
        """Score a heal candidate: higher is better, -1 if dead, 0 if it has `can_attack`."""
        if unit.health <= 0:
            return -1
        if hasattr(unit, 'can_attack'):
            # Only objects without `can_attack` are treated as allies and scored.
            return 0

        unit_type = unit.unit_type.lower()
        score = unit_priorities.get(unit_type, 5.0)

        # Health: quadratic boost below the type's urgency threshold, linear above it.
        health_deficit = 1.0 - unit.health
        if unit.health < healing_urgency.get(unit_type, 0.6):
            score *= 2.0 * (1 + health_deficit) ** 2
        else:
            score *= 1 + health_deficit

        # Distance: bonus inside optimal_range, linear falloff (floored at 0.5) beyond it.
        optimal_range = 0.6
        if unit.distance > optimal_range:
            score *= max(0.5, 1.0 - (unit.distance - optimal_range))
        else:
            score *= 1.2

        if enemy_centroid is not None:
            # Favour allies farther from the nearby-enemy centroid (factor capped at 1.5).
            unit_to_enemy_dist = ((unit.position[0] - enemy_centroid[0]) ** 2 +
                                  (unit.position[1] - enemy_centroid[1]) ** 2) ** 0.5
            score *= min(1.5, unit_to_enemy_dist)
            # Combat bonus: applies to every ally whenever any enemy is within 1.0.
            score *= 1.3

        # Persistence: stick with the ally targeted by the previous (heal) action.
        if unit.id == obs_data.last_action - 6:
            score *= 1.4

        # Group cohesion: the Medivac itself has at least two allies within 0.5.
        if close_allies >= 2:
            score *= 1.2

        return score

    def retreat_action(heal_actions):
        """Path toward a point pushed past the allied centroid away from the enemies; may be falsy."""
        enemies_in_range = [enemy for enemy in obs_data.enemies if enemy.distance < 1]
        # NOTE(review): these are Protoss/Zerg types, so for Terran allies this is
        # presumably always empty and the full allied centroid is used.
        melee_allies = [ally for ally in obs_data.allies
                        if ally.unit_type.lower() in ['zealot', 'zergling', 'baneling']]
        anchor = melee_allies or obs_data.allies
        ally_x = sum(ally.position[0] for ally in anchor) / len(anchor)
        ally_y = sum(ally.position[1] for ally in anchor) / len(anchor)

        # Use only in-range enemies when there are any, otherwise all visible enemies.
        threat = enemies_in_range or obs_data.enemies
        enemy_center = (sum(e.position[0] for e in threat) / len(threat),
                        sum(e.position[1] for e in threat) / len(threat))

        dx = ally_x - enemy_center[0]
        dy = ally_y - enemy_center[1]
        # Step 2/sight_range further along each axis, away from the enemy centroid.
        safe_x = ally_x + (dx / abs(dx)) * 2 / obs_data.own_sight_range if dx != 0 else ally_x
        safe_y = ally_y + (dy / abs(dy)) * 2 / obs_data.own_sight_range if dy != 0 else ally_y

        # NOTE(review): treats positions as relative to the Medivac, so this is our distance to the safe point.
        safe_distance = (safe_x ** 2 + safe_y ** 2) ** 0.5
        criterion = 5 / obs_data.own_sight_range
        if obs_data.last_action >= 6 or len(heal_actions) == 0:
            target_angle = math.atan2(enemy_center[1], enemy_center[0])
            safe_angle = math.atan2(ally_y, ally_x)
            angle_diff = abs(target_angle - safe_angle)
            # Move only if the safe point is far enough and not within ~20 degrees of the enemy direction.
            if safe_distance > criterion and (math.pi / 9 < angle_diff < 17 * math.pi / 9):
                return find_path(obs_data, safe_x, safe_y)
        return None

    def flee_action():
        """With no allies left, path away from the enemy centroid (or toward map centre if off-map)."""
        enemy_x = sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies)
        enemy_y = sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies)
        target_x = -enemy_x
        target_y = -enemy_y

        g_x = target_x * obs_data.own_sight_range + (obs_data.own_position[0] * map_size)
        g_y = target_y * obs_data.own_sight_range + (obs_data.own_position[1] * map_size)
        if not (0 <= g_x <= map_size and 0 <= g_y <= map_size):
            target_x = (0.5 - obs_data.own_position[0]) * map_size / obs_data.own_sight_range
            target_y = (0.5 - obs_data.own_position[1]) * map_size / obs_data.own_sight_range

        return find_path(obs_data, target_x, target_y)

    def control_logic():
        # Target-unit actions (ids >= 6); for a Medivac these are heal actions (`id + 6`).
        heal_actions = [a for a in valid_actions if a >= 6]

        if obs_data.allies:
            if obs_data.enemies:
                path_action = retreat_action(heal_actions)
                if path_action:
                    return path_action

            # Scoring inputs that do not depend on the scored unit are computed once.
            enemy_centroid, close_allies = scoring_context()
            target_scores = {ally.id: score_target(ally, enemy_centroid, close_allies)
                             for ally in obs_data.allies}
            max_score = max(target_scores.values())
            best_ids = [target_id for target_id, score in target_scores.items() if score == max_score]
            candidates = [ally for ally in obs_data.allies if ally.id in best_ids]
            # Ties go to the closest candidate (min keeps the first on equal distance).
            best_target = min(candidates, key=lambda x: x.distance)
            closest_ally = min(obs_data.allies, key=lambda x: x.distance)
            lowest_health_ally = min(obs_data.allies, key=lambda x: x.health)

            # Heal the first of: best-scored, closest, most damaged — if healable and not near full.
            for candidate in (best_target, closest_ally, lowest_health_ally):
                if (candidate.id + 6) in valid_actions and 0 < candidate.health < heal_cutoff:
                    return heal(candidate.id)

            # Nobody healable right now: move toward the best target.
            path_action = find_path(obs_data, best_target.position[0], best_target.position[1],
                                    target_type=best_target.unit_type.lower())
            if path_action:
                return path_action
        elif obs_data.enemies:
            path_action = flee_action()
            if path_action:
                return path_action

        # Default behavior
        return default_action(obs)

    return control_logic()

# Public API of this module: only the registered Medivac skill is exported.
__all__ = ["race_terran_medivac_navi_A_star_score_priority_health_default_allies"]

if __name__ == "__main__":
    # Smoke test: run every exported skill against each shared fixture parameter set.
    from harl.test import test_params

    for skill_name in __all__:
        print(f"Testing {skill_name}...")
        skill_fn = globals()[skill_name]
        for test_id, test_param in enumerate(test_params):
            print(f"Test param: {test_id}")
            print(f"Result: {skill_fn(**test_param)}")