from harl.common.skills.smacv2.atomic_actions.combat import attack
from harl.common.skills.smacv2.atomic_actions.move import move_north, move_south, move_east, move_west
from harl.common.skills.smacv2.atomic_actions.basic import stop
from harl.common.skills.smacv2.atomic_actions.heal import heal
from harl.common.skills.smacv2.composite_skills import default_tactic as default_action, find_path
from harl.common.skills.smacv2.skill_registry import register_skill
from harl.utils.skill_utils import parse_obs
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Set
import random

@register_skill("race_zerg_baneling_navi_A_star_score_cluster_default_center")
def race_zerg_baneling_navi_A_star_score_cluster_default_center(obs: str):
    """Choose a discrete action for a Zerg baneling agent.

    Decision procedure (first rule that yields an action wins):
        1. If action 0 is available, return 0 (presumably SMAC's no-op,
           which is only offered to dead agents -- confirm).
        2. With enemies visible:
           a. If melee allies (zealot / zergling / baneling) are visible and
              either no attack action is available or the regroup point is
              far (> 0.5), path toward a point halfway to the melee allies'
              centroid.  This only happens if that point is beyond
              ``2 / own_sight_range`` and is off the enemy bearing by more
              than pi/9 (or is far away).
           b. Otherwise score every enemy with ``score_target`` plus an
              exponential cluster bonus.  The bonus counts the enemies within
              a small radius, the enemy itself included, to reward splash
              damage.  Pick the best score, breaking exact ties by distance.
              Attack it if its ``can_attack`` flag is set.  Otherwise path
              toward it, or fall back to another available attack action.
        3. Without enemies: regroup toward the melee allies' centroid, or
           toward all allies' centroid when there are no melee allies.
        4. If no rule produced an action, defer to ``default_action``.

    Focus fire is coordinated implicitly: enemies that allies (or this unit)
    attacked last step get a score multiplier.

    Args:
        obs (str): Observation string containing game state; parsed with
            ``parse_obs``.

    Returns:
        An action produced by ``attack``, ``find_path`` or ``default_action``,
        or a raw action id (``0``, ``obs_data.last_action`` or an element of
        ``available_actions``).
    """
    import math
    # Parse observation
    obs_data = parse_obs(obs)
    # Get set of available actions
    valid_actions = obs_data.available_actions

    if 0 in valid_actions:
        return 0

    # Unit types treated as melee everywhere below; all others count as ranged.
    melee_types = frozenset(('zealot', 'zergling', 'baneling'))

    # Base target priority per enemy unit type (unknown types default to 5.0).
    unit_priorities = {
        'hydralisk': 30.0,  # High priority for their sustained DPS
        'zergling': 35.0,   # Medium priority as swarm units
        'baneling': 45.0,   # Critical priority due to splash damage
    }

    # Matchup multipliers: unit_counters[own_type][target_type].
    unit_counters = {
        'baneling': {'hydralisk': 1.2, 'zergling': 1.5, 'baneling': 1.0},
    }

    def score_target(unit):
        """Score ``unit`` as a target; higher is better, -1 for dead units.

        Units that have a ``can_attack`` attribute are scored as enemies from
        type priority, distance, shields, proximity to ranged allies, local
        combat power, focus fire and health.  Any other unit is scored as an
        ally from health, shields and distance.  In this file the function is
        only ever called on enemies.
        """
        if unit.health <= 0:
            return -1

        base_priority = unit_priorities.get(unit.unit_type.lower(), 5.0)
        is_ranged = unit.unit_type.lower() not in melee_types
        own_is_ranged = obs_data.own_unit_type.lower() not in melee_types

        # Scale by how well our unit type matches up against the target's type.
        matchup_mult = unit_counters.get(obs_data.own_unit_type.lower(), {}).get(unit.unit_type.lower(), 1.0)
        base_priority *= matchup_mult

        if hasattr(unit, 'can_attack'):  # Enemy unit
            score = base_priority

            # Prefer targets near an "optimal" engagement distance.
            optimal_range = 0.8 if is_ranged else 0.6
            distance_factor = 2.0 / (1 + abs(unit.distance - optimal_range))
            score *= distance_factor

            # Slight preference for targets with high shields.
            if unit.shield is not None:
                shield_factor = 1.2 if unit.shield > 0.5 else 1.0
                score *= shield_factor

            # Melee units prefer targets close to the centroid of ranged allies.
            if not own_is_ranged:
                range_ally = [ally for ally in obs_data.allies if
                              ally.unit_type.lower() not in melee_types]
                if range_ally:
                    ally_center = tuple(sum(coord) / len(range_ally)
                                        for coord in zip(*[ally.position for ally in range_ally]))
                    ally_distance = ((ally_center[0] - unit.position[0])**2 +
                                     (ally_center[1] - unit.position[1])**2)**0.5
                    support_factor = 2.0 / (1 + ally_distance)
                    score *= support_factor

            def calculate_combat_power(units, base_radius=0.3):
                """Return (weighted power of ``units`` near the target, melee/ranged counts).

                Only units within ``radius`` of the target contribute power and
                counts.  The formation bonus uses the spread of *all* ``units``.
                """
                total_power = 0
                positions = []
                unit_count = {'melee': 0, 'ranged': 0}

                # The radius grows by 10% of the base per unit considered.
                radius = base_radius * (1 + len(units) * 0.1)

                for u in units:
                    dist = ((u.position[0] - unit.position[0])**2 +
                            (u.position[1] - unit.position[1])**2)**0.5
                    positions.append(u.position)

                    if dist <= radius:
                        role = 'melee' if u.unit_type.lower() in melee_types else 'ranged'
                        unit_count[role] += 1

                        base_power = unit_priorities.get(u.unit_type.lower(), 5.0)

                        # Healthier / shielded units contribute more power.
                        health_shield_factor = 1.0
                        if u.health > 0.7:
                            health_shield_factor = 1.5
                        elif u.health > 0.4:
                            health_shield_factor = 1.2
                        if u.shield and u.shield > 0.5:
                            health_shield_factor *= 1.2

                        # Closer units weigh more: 1.5 at the target, 0.5 at the radius edge.
                        position_factor = 1.5 - (dist/radius)

                        total_power += base_power * health_shield_factor * position_factor

                # Tight formations (small mean distance to centroid) multiply
                # total power by (1 + bonus), with bonus in (0, 2].
                formation_bonus = 0
                if len(positions) >= 2:
                    center = tuple(sum(coord) / len(positions) for coord in zip(*positions))
                    avg_dist = sum(((p[0] - center[0])**2 + (p[1] - center[1])**2)**0.5
                                   for p in positions) / len(positions)
                    formation_bonus = 2.0 / (1 + avg_dist * 2)

                return total_power * (1 + formation_bonus), unit_count

            ally_power, ally_counts = calculate_combat_power(obs_data.allies)
            enemy_power, enemy_counts = calculate_combat_power(obs_data.enemies)

            # Allies whose last action attacked this unit (attack ids are offset by 6).
            num_attackers = sum(1 for ally in obs_data.allies
                                if ally.last_action >= 6 and ally.last_action - 6 == unit.id)

            # Stick with our own previous target, more so when it is low on health.
            if unit.id == obs_data.last_action - 6:
                persistence_bonus = 2.5 if unit.health < 0.3 else 2.0
                score *= persistence_bonus

            # Reward focus fire (capped at 2.5x).  Discourage joining a target that
            # 3+ allies already attack unless we were attacking it ourselves.
            if num_attackers > 0:
                focus_bonus = min(1.3 ** num_attackers, 2.5)
                if num_attackers >= 3 and unit.id != obs_data.last_action - 6:
                    focus_bonus = 0.4
                score *= focus_bonus

            # Local power advantage around the target.
            advantage_factor = 1.0
            power_ratio = ally_power / max(enemy_power, 0.1)

            if power_ratio > 1.3:
                advantage_factor = 1.3
                if ally_counts['melee'] >= 3:
                    advantage_factor *= 1.3

            # Strongly prefer isolated targets; mildly prefer a local numbers advantage.
            total_enemies = enemy_counts['melee'] + enemy_counts['ranged']
            if total_enemies == 1:
                advantage_factor *= 2.5
            elif (ally_counts['melee'] + ally_counts['ranged']) > total_enemies:
                advantage_factor *= 1.3

            score *= advantage_factor

            # Lower health -> higher score (2.0 at health 0, 1.0 at health 1).
            health_factor = 2.0 / (1 + unit.health)
            score *= health_factor

        else:  # Ally unit
            score = base_priority

            # Wounded allies (and allies with low shields) rank higher.
            health_factor = 2.0 / (1 + unit.health)
            if unit.shield is not None:
                shield_factor = 1.2 if unit.shield < 0.3 else 1.0
                health_factor *= shield_factor
            score *= health_factor

            # Closer allies rank higher.
            distance_factor = 2.0 / (1 + unit.distance)
            score *= distance_factor

        return score

    def control_logic():
        """Pick an action from ``obs_data`` (see the enclosing docstring)."""
        if obs_data.enemies:
            enemy_in_range = [enemy for enemy in obs_data.enemies if enemy.distance < 1]
            # Attack actions are offset by 6 (action 6 + k targets unit id k).
            attack_actions = [a for a in valid_actions if a >= 6]

            # Centroid of enemies closer than 1, or of all enemies if none is that close.
            engaged = enemy_in_range if enemy_in_range else obs_data.enemies
            enemy_center = (sum(e.position[0] for e in engaged) / len(engaged),
                            sum(e.position[1] for e in engaged) / len(engaged))

            if obs_data.allies:
                melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in melee_types]
                if melee_ally:
                    # Regroup point: half of the melee allies' centroid.  Positions
                    # appear to be relative to this unit, so this is halfway toward
                    # the group (TODO confirm the frame in parse_obs).
                    safe_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally) / 2
                    safe_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally) / 2

                    distance = (safe_x ** 2 + safe_y ** 2) ** 0.5
                    criterion = 2/obs_data.own_sight_range
                    if len(attack_actions) == 0 or distance > 0.5:
                        target_angle = math.atan2(enemy_center[1], enemy_center[0])
                        safe_angle = math.atan2(safe_y, safe_x)
                        angle_diff = abs(target_angle - safe_angle)
                        # Regroup only if the point is off the enemy bearing by more
                        # than pi/9 (20 degrees), or if it is far away.
                        if distance > criterion and (math.pi/9 < angle_diff < 17*math.pi/9 or distance > 0.5):
                            path_action = find_path(obs_data, safe_x, safe_y)
                            if path_action:
                                return path_action

            # For each enemy, count enemies (itself included) within cluster_radius;
            # larger clusters are better splash-damage targets.
            cluster_radius = 0.3 if obs_data.own_unit_type.lower() == 'baneling' else 0.2
            enemy_clusters = {}
            for enemy in obs_data.enemies:
                enemy_clusters[enemy.id] = sum(
                    1 for other in obs_data.enemies
                    if ((other.position[0] - enemy.position[0])**2 +
                        (other.position[1] - enemy.position[1])**2)**0.5 <= cluster_radius
                )

            # Final score = tactical score + exponential cluster bonus.
            cluster_base = 1.5 if obs_data.own_unit_type.lower() == 'baneling' else 1.2
            target_scores = {}
            for enemy in obs_data.enemies:
                target_scores[enemy.id] = score_target(enemy) + cluster_base ** enemy_clusters[enemy.id]

            # Only exact ties share the top score; break them by distance.
            max_score = max(target_scores.values())
            max_score_targets = [enemy for enemy in obs_data.enemies
                                 if target_scores[enemy.id] >= max_score]
            if len(max_score_targets) > 1:
                best_target = min(max_score_targets, key=lambda x: x.distance)
            else:
                best_target = max_score_targets[0]

            if best_target.can_attack:
                return attack(best_target.id)
            else:
                # Not attackable yet: path toward the target.
                dx = best_target.position[0]
                dy = best_target.position[1]
                path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                if path_action:
                    return path_action
                elif attack_actions:
                    # No path: keep the previous attack, else hit the nearest
                    # attackable enemy, else any available attack.
                    attackable_enemies = [enemy for enemy in obs_data.enemies if enemy.can_attack]
                    if obs_data.last_action in attack_actions:
                        return obs_data.last_action
                    if attackable_enemies:
                        return attack(min(attackable_enemies, key=lambda e: e.distance).id)
                    return random.choice(attack_actions)

        # If there are no enemies
        else:
            if obs_data.allies:
                melee_allies = [ally for ally in obs_data.allies
                                if ally.unit_type.lower() in melee_types]
                if melee_allies:
                    spacing = 0.1 if obs_data.own_unit_type.lower() == 'baneling' else 0.05
                    center_x = sum(ally.position[0] for ally in melee_allies) / len(melee_allies)
                    center_y = sum(ally.position[1] for ally in melee_allies) / len(melee_allies)

                    # Largest distance of any melee ally from the group centroid.
                    max_spread = max(((ally.position[0] - center_x)**2 +
                                      (ally.position[1] - center_y)**2)**0.5
                                     for ally in melee_allies)

                    own_distance = ((center_x)**2 + (center_y)**2)**0.5

                    if own_distance > spacing or max_spread > 0.1:
                        # Aim 15% short of the centroid to avoid overcrowding.
                        adjusted_x = center_x * 0.85
                        adjusted_y = center_y * 0.85
                        path_action = find_path(obs_data, adjusted_x, adjusted_y)
                        if path_action:
                            return path_action
                else:
                    # No melee allies: regroup on the centroid of all allies.
                    ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies)
                    ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies)
                    distance = (ally_x ** 2 + ally_y ** 2) ** 0.5
                    if distance > 0.05:
                        path_action = find_path(obs_data, ally_x, ally_y)
                        if path_action:
                            return path_action

        return default_action(obs)
    return control_logic()

__all__ = [
    "race_zerg_baneling_navi_A_star_score_cluster_default_center",
]

if __name__ == "__main__":
    # Smoke-test every exported skill against the shared fixtures in
    # harl.test; each entry of test_params is a kwargs dict for the skill.
    from harl.test import test_params

    for skill in __all__:
        print(f"Testing {skill}...")
        for test_id, test_param in enumerate(test_params):
            print(f"Test param: {test_id}")
            print(f"Result: {globals()[skill](**test_param)}")