from harl.common.skills.smacv2.atomic_actions.combat import attack
from harl.common.skills.smacv2.atomic_actions.move import move_north, move_south, move_east, move_west
from harl.common.skills.smacv2.atomic_actions.basic import stop
from harl.common.skills.smacv2.atomic_actions.heal import heal
from harl.common.skills.smacv2.composite_skills import default_tactic as default_action, find_path
from harl.common.skills.smacv2.skill_registry import register_skill
from harl.utils.skill_utils import parse_obs
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Set
import random

@register_skill("race_protoss_stalker_navi_A_star_score_strategic_default_cover")
def race_protoss_stalker_navi_A_star_score_strategic_default_cover(obs: str):
    """
    Stalker control script optimized for strategic ranged combat and shield management.

    Key strategies:
    1. Optimal range positioning with shield state awareness
    2. Focus fire coordination on priority targets
    3. Strategic retreats for shield regeneration
    4. Protected position seeking behind zealot frontline
    5. Coordinated attacks on armored targets

    The script emphasizes maintaining effective range and shield management
    while coordinating with other Protoss units.

    Unit positions and distances appear to be relative to this agent (the
    agent is treated as the origin) and scaled by its sight range.
    NOTE(review): confirm against ``parse_obs``.

    Args:
        obs (str): Observation string containing game state

    Returns:
        0 if action 0 is available.
        Otherwise, the first rule that fires returns one of:
        * an attack action built by ``attack``;
        * a movement action returned by ``find_path``;
        * one of the currently available attack actions.
        When no rule fires, returns ``default_action(obs)``.
    """
    import math

    # Parse the observation once (it used to be parsed twice with identical results).
    obs_data = parse_obs(obs)
    valid_actions = obs_data.available_actions

    # Short-circuit when action 0 is available.
    # Presumably this is the no-op SMAC offers only to dead agents -- TODO confirm.
    if 0 in valid_actions:
        return 0

    # Unit types treated as melee throughout this skill.
    melee_types = ('zealot', 'zergling', 'baneling')

    def shield_ratio_of(shield, health):
        """Return shield / (health + shield), or 0.0 when that sum is not positive.

        The guard prevents a ZeroDivisionError for units that report zero
        health and zero shield (e.g. dead units, if ``parse_obs`` keeps them).
        """
        total = health + shield
        if total <= 0:
            return 0.0
        return shield / total

    def score_target(unit):
        """Score ``unit`` as an attack target (enemy) or support candidate (ally).

        Units exposing a ``can_attack`` attribute are scored as enemies.
        All others are scored as allies. Dead units (health <= 0) score -1.
        Higher scores are preferred by the callers.
        """
        if unit.health <= 0:
            return -1

        # Base priority per unit type; unknown types default to 5.0.
        unit_priorities = {
            'colossus': 35.0,  # Further increased priority
            'stalker': 30.0,   # Enhanced anti-armor focus
            'zealot': 45.0,    # Higher melee threat recognition
        }

        # Matchup multipliers, keyed by own unit type and then target unit type.
        unit_counters = {
            'stalker': {
                'colossus': 1.2,
                'stalker': 1.0,  # Increased for mirror matchups
                'zealot': 1.5,   # Higher priority when zealots approach
            }
        }

        base_priority = unit_priorities.get(unit.unit_type.lower(), 5.0)
        matchup_mult = unit_counters.get(obs_data.own_unit_type.lower(), {}).get(unit.unit_type.lower(), 1.0)
        base_priority *= matchup_mult

        if hasattr(unit, 'can_attack'):  # Enemy unit
            score = base_priority

            # Our own shield state scales how aggressively distance is weighed.
            shield_factor = 1.0
            if hasattr(obs_data, 'own_shield') and obs_data.own_shield is not None:
                shield_ratio = shield_ratio_of(obs_data.own_shield, obs_data.own_health)
                if shield_ratio < 0.3:  # Low shields
                    shield_factor = 0.7  # More cautious
                elif shield_ratio > 0.7:  # High shields
                    shield_factor = 1.3  # More aggressive

            # Closer targets score higher; the factor is floored at 0.5.
            distance_factor = max((1.2 - unit.distance) * shield_factor, 0.5)
            score *= distance_factor

            position_x, position_y = unit.position

            def calculate_combat_power(units, radius=0.25):
                """Estimate the combat power of ``units`` within ``radius`` of the target.

                Returns a tuple of (power scaled by formation cohesion,
                melee count in radius, ranged count in radius, sum of shield
                ratios in radius).
                """
                total_power = 0
                ranged_count = 0
                melee_count = 0
                shield_sum = 0
                unit_positions = []

                for u in units:
                    dist = ((u.position[0] - position_x)**2 +
                        (u.position[1] - position_y)**2)**0.5
                    # Cohesion below uses every unit's position, not just those in radius.
                    unit_positions.append(u.position)

                    if dist <= radius:
                        base_power = unit_priorities.get(u.unit_type.lower(), 5.0)

                        # Shielded units contribute more power.
                        if hasattr(u, 'shield') and u.shield is not None:
                            shield_ratio = shield_ratio_of(u.shield, u.health)
                            shield_sum += shield_ratio
                            base_power *= (0.8 + shield_ratio * 0.4)

                        # Group bonuses once enough melee / ranged units are clustered.
                        if u.unit_type.lower() in melee_types:
                            melee_count += 1
                            if melee_count >= 2:
                                base_power *= 1.5
                        else:
                            ranged_count += 1
                            if ranged_count >= 3:  # Stronger ranged group bonus
                                base_power *= 1.4

                        health_factor = 1.6 if u.health > 0.8 else 1.2 if u.health > 0.5 else 0.8
                        # Units nearer the target position weigh more.
                        position_factor = 1.4 - (dist/radius)

                        total_power += base_power * health_factor * position_factor

                # Tighter formations (smaller mean spread around the centroid) get a larger bonus.
                cohesion = 0
                if len(unit_positions) > 2:
                    center_x = sum(p[0] for p in unit_positions) / len(unit_positions)
                    center_y = sum(p[1] for p in unit_positions) / len(unit_positions)
                    avg_dist = sum(((p[0] - center_x)**2 + (p[1] - center_y)**2)**0.5
                                for p in unit_positions) / len(unit_positions)
                    max_desired_dist = 0.25
                    cohesion = 2.2 / (1.0 + (avg_dist / max_desired_dist))

                return total_power * (1 + cohesion), melee_count, ranged_count, shield_sum

            ally_power, ally_swarms, ally_ranged, ally_shields = calculate_combat_power(obs_data.allies)
            enemy_power, enemy_swarms, enemy_ranged, enemy_shields = calculate_combat_power(obs_data.enemies)

            # Attack actions encode the target as (action id - 6).
            num_attackers = sum(1 for ally in obs_data.allies
                            if ally.last_action >= 6 and ally.last_action - 6 == unit.id)

            # Prefer to keep shooting the target we attacked last step.
            if unit.id == obs_data.last_action - 6:
                persistence_bonus = 2.2
                score *= persistence_bonus

            if num_attackers > 0:
                focus_bonus = 1.3 ** min(num_attackers, 3)  # Capped scaling
                if num_attackers >= 4 and unit.id != obs_data.last_action - 6:
                    focus_bonus = 0.6  # Prevent excessive focus
                score *= focus_bonus

            # Local power balance around the target.
            advantage_factor = 1.0
            power_ratio = ally_power / max(enemy_power, 0.1)
            shield_advantage = (ally_shields / max(len(obs_data.allies), 1)) - (enemy_shields / max(len(obs_data.enemies), 1))

            if power_ratio > 1.4:
                advantage_factor = 1.3
                if shield_advantage > 0.2:  # Strong shield advantage
                    advantage_factor *= 1.2
            elif power_ratio < 0.7:  # Significant disadvantage
                advantage_factor = 0.7

            # Priority for isolated targets (a single enemy within the radius).
            if (enemy_swarms + enemy_ranged) == 1:
                advantage_factor *= 2.2
            elif (ally_swarms + ally_ranged) > (enemy_swarms + enemy_ranged):
                advantage_factor *= 1.3

            score *= advantage_factor

            # Damaged and low-shield targets score higher.
            health_factor = (1.2 - unit.health) + 1
            if hasattr(unit, 'shield') and unit.shield is not None:
                shield_ratio = shield_ratio_of(unit.shield, unit.health)
                health_factor *= (1.3 - shield_ratio * 0.3)  # Higher priority for low shield targets
            score *= health_factor

        else:  # Ally unit
            score = base_priority

            # Damaged and low-shield allies score higher.
            health_factor = (1.2 - unit.health) + 1
            if hasattr(unit, 'shield') and unit.shield is not None:
                shield_ratio = shield_ratio_of(unit.shield, unit.health)
                health_factor *= (1.4 - shield_ratio * 0.4)  # Higher priority for low shield allies
            score *= health_factor

            distance_factor = max((1.2 - unit.distance) + 1, 0.5)
            score *= distance_factor

        return score

    def control_logic():
        """Pick an action: retreat to a safe point, focus fire, kite, or regroup.

        * Enemies and allies visible: compute a safe point near the (melee)
          allies, move there if needed, otherwise attack the best-scoring enemy.
        * Only enemies visible: kite close melee enemies, otherwise attack
          the highest-scoring enemy.
        * Only allies visible: move toward the (melee) ally centroid.
        Falls back to ``default_action(obs)`` when no rule returns an action.
        """
        # Action ids >= 6 are attack actions (target id = action id - 6).
        attack_actions = [a for a in valid_actions if a >= 6]
        if obs_data.enemies:
            # Centre of the enemies within distance 1, or of all enemies if none are that close.
            enemy_in_range = [enemy for enemy in obs_data.enemies if enemy.distance < 1]
            if len(enemy_in_range) > 0:
                enemy_center = (sum(e.position[0] for e in enemy_in_range) / len(enemy_in_range),
                            sum(e.position[1] for e in enemy_in_range) / len(enemy_in_range))
            else:
                enemy_center = (sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies),
                            sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies))
            if obs_data.allies:
                melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in melee_types]
                melee_enemy = [enemy for enemy in enemy_in_range if enemy.unit_type.lower() in melee_types]
                if melee_ally:
                    # Rally on the centroid of the melee allies (the frontline).
                    ally_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally)
                    ally_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally)
                else:
                    # Half the ally centroid; if positions are agent-relative this is
                    # midway between this agent and its allies.
                    ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies) / 2
                    ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies) / 2

                dx = ally_x - enemy_center[0]
                dy = ally_y - enemy_center[1]
                safe_x = ally_x
                safe_y = ally_y
                if melee_ally:
                    # Stand behind the frontline: step away from the enemy centre on each axis.
                    safe_x = safe_x + (dx / abs(dx)) * 2/obs_data.own_sight_range if dx != 0 else safe_x
                    safe_y = safe_y + (dy / abs(dy)) * 2/obs_data.own_sight_range if dy != 0 else safe_y
                melee_threaten = False
                if melee_enemy:
                    closest_melee_enemy = min(melee_enemy, key=lambda x: x.distance)
                    if closest_melee_enemy.distance <= 4/obs_data.own_sight_range:
                        # A melee enemy is close: push the safe point further away from it.
                        melee_threaten = True
                        dx = safe_x - closest_melee_enemy.position[0]
                        dy = safe_y - closest_melee_enemy.position[1]
                        safe_x = safe_x + (dx / abs(dx)) * 1/obs_data.own_sight_range if dx != 0 else safe_x
                        safe_y = safe_y + (dy / abs(dy)) * 1/obs_data.own_sight_range if dy != 0 else safe_y

                distance = (safe_x ** 2 + safe_y ** 2) ** 0.5
                criterion = 4/obs_data.own_sight_range
                # Consider repositioning right after an attack, when nothing is
                # attackable, or when the safe point is far away.
                if obs_data.last_action >= 6 or len(attack_actions) == 0 or distance > 0.9:
                    target_angle = math.atan2(enemy_center[1], enemy_center[0])
                    safe_angle = math.atan2(ally_y, ally_x)
                    angle_diff = abs(target_angle - safe_angle)
                    # Move when the rally direction differs from the enemy direction by
                    # more than ~20 degrees (pi/9), or when threatened / far away.
                    if distance > criterion and ((math.pi/9 < angle_diff < 17*math.pi/9) or melee_threaten or distance > 0.9):
                        path_action = find_path(obs_data, safe_x, safe_y)
                        if path_action:
                            return path_action
                # Focus fire: count how many allies are attacking each enemy.
                target_counts = {}
                for ally in obs_data.allies:
                    if ally.last_action >= 6:
                        target_id = ally.last_action - 6
                        target_counts[target_id] = target_counts.get(target_id, 0) + 1
                # Enemies far from the safe point are scored down.
                enemy_safe_distance = {enemy.id: ((enemy.position[0] - safe_x) ** 2 + (enemy.position[1] - safe_y) ** 2) ** 0.5 for enemy in obs_data.enemies}
                # Combine threat scoring with the focus-fire counts.
                target_scores = {enemy.id: score_target(enemy) for enemy in obs_data.enemies}
                for target_id, count in target_counts.items():
                    if target_id in target_scores:
                        target_scores[target_id] += count * 0.5
                for target_id, scores in target_scores.items():
                    target_scores[target_id] = scores * (1 - enemy_safe_distance[target_id] * 0.3)
                best_target_id = max(target_scores.items(), key=lambda x: x[1])[0]
                best_target = next(enemy for enemy in obs_data.enemies if enemy.id == best_target_id)
                if best_target.can_attack:
                    return attack(best_target_id)
                else:
                    # Best target is not in shoot range, move to target
                    dx = best_target.position[0]
                    dy = best_target.position[1]

                    # Only approach the target if its direction agrees with the safe point's.
                    target_angle = math.atan2(dy, dx)
                    safe_angle = math.atan2(ally_y, ally_x)
                    angle_diff = abs(target_angle - safe_angle)
                    # Approach only if the angle difference is within ~20 degrees (pi/9),
                    # or if there is no melee frontline to stay behind.
                    if angle_diff < math.pi/9 or angle_diff > 17*math.pi/9 or not melee_ally:
                        if best_target.distance > obs_data.own_shoot_range / obs_data.own_sight_range:
                            path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                            if path_action:
                                return path_action
                    if attack_actions:
                        # Keep the previous attack if still valid, else hit the closest attackable enemy.
                        if obs_data.last_action in attack_actions:
                            return obs_data.last_action
                        attackable_enemies = [enemy for enemy in obs_data.enemies if enemy.can_attack]
                        closest_enemy = min(
                            attackable_enemies,
                            key=lambda enemy: enemy.distance,
                            default=None,
                        )
                        if closest_enemy and closest_enemy.can_attack:
                            return attack(closest_enemy.id)
                        return random.choice(attack_actions)
                    else:
                        if distance > criterion:
                            path_action = find_path(obs_data, safe_x, safe_y)
                            if path_action:
                                return path_action

            # If there are only enemies
            else:
                closest_enemy = min(
                    obs_data.enemies,
                    key=lambda enemy: enemy.distance,
                    default=None,
                )
                # No allies: kite melee enemies.
                if closest_enemy.unit_type.lower() in melee_types:
                    if closest_enemy.distance <= 4 / obs_data.own_sight_range and obs_data.last_action >= 6:
                        # Head directly away from the enemy centroid.
                        enemy_x = sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies)
                        enemy_y = sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies)
                        target_x = - enemy_x
                        target_y = - enemy_y

                        # If the retreat point leaves the map, head toward the map centre.
                        # NOTE(review): assumes a 32x32 map and own_position
                        # normalised to [0, 1] -- TODO confirm.
                        g_x = target_x * obs_data.own_sight_range + (obs_data.own_position[0] * 32)
                        g_y = target_y * obs_data.own_sight_range + (obs_data.own_position[1] * 32)
                        if not (0 <= g_x <= 32 and 0 <= g_y <= 32):
                            target_x = (0.5 - obs_data.own_position[0]) * 32 / obs_data.own_sight_range
                            target_y = (0.5 - obs_data.own_position[1]) * 32 / obs_data.own_sight_range

                        path_action = find_path(obs_data, target_x, target_y)
                        if path_action:
                            return path_action
                    if closest_enemy.can_attack:
                        return attack(closest_enemy.id)
                else:
                    # No melee enemies: target the highest-scoring enemy, ties broken by distance.
                    target_scores = {enemy.id: score_target(enemy) for enemy in obs_data.enemies}
                    max_score = max(target_scores.values())
                    max_score_target_ids = [target_id for target_id, score in target_scores.items() if score == max_score]
                    max_score_targets = [enemy for enemy in obs_data.enemies if enemy.id in max_score_target_ids]
                    if len(max_score_targets) > 1:
                        best_target = min(max_score_targets, key=lambda x: x.distance)
                    else:
                        best_target = max_score_targets[0]
                    if best_target.can_attack:
                        return attack(best_target.id)
                    else:
                        # Best target is not in shoot range, move to target
                        dx = best_target.position[0]
                        dy = best_target.position[1]
                        if best_target.distance > obs_data.own_shoot_range / obs_data.own_sight_range:
                            path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                            if path_action:
                                return path_action
                            elif attack_actions:
                                if obs_data.last_action in attack_actions:
                                    return obs_data.last_action
                                attackable_enemies = [enemy for enemy in obs_data.enemies if enemy.can_attack]
                                closest_enemy = min(
                                    attackable_enemies,
                                    key=lambda enemy: enemy.distance,
                                    default=None,
                                )
                                if closest_enemy and closest_enemy.can_attack:
                                    return attack(closest_enemy.id)
                                return random.choice(attack_actions)
        # If there are no enemies
        else:
            if obs_data.allies:
                melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in melee_types]
                if melee_ally:
                    # Regroup on the centroid of the melee allies.
                    melee_ally_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally)
                    melee_ally_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally)
                    dx = melee_ally_x
                    dy = melee_ally_y
                    distance = (dx ** 2 + dy ** 2) ** 0.5
                    if distance > 0.05:
                        path_action = find_path(obs_data, dx, dy)
                        if path_action:
                            return path_action
                else:
                    # No melee allies: regroup on the centroid of all allies.
                    ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies)
                    ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies)
                    distance = (ally_x ** 2 + ally_y ** 2) ** 0.5
                    if distance > 0.05:
                        dx = ally_x
                        dy = ally_y
                        path_action = find_path(obs_data, dx, dy)
                        if path_action:
                            return path_action
        return default_action(obs)
    return control_logic()

__all__ = [
    "race_protoss_stalker_navi_A_star_score_strategic_default_cover",
]

if __name__ == "__main__":
    # Smoke-test every exported skill against the shared test observations.
    from harl.test import test_params

    for skill in __all__:
        print(f"Testing {skill}...")
        for test_id, test_param in enumerate(test_params):
            print(f"Test param: {test_id}")
            print(f"Result: {globals()[skill](**test_param)}")