from harl.common.skills.smacv2.atomic_actions.combat import attack
from harl.common.skills.smacv2.atomic_actions.move import move_north, move_south, move_east, move_west
from harl.common.skills.smacv2.atomic_actions.basic import stop
from harl.common.skills.smacv2.atomic_actions.heal import heal
from harl.common.skills.smacv2.composite_skills import default_tactic as default_action, find_path
from harl.common.skills.smacv2.skill_registry import register_skill
from harl.utils.skill_utils import parse_obs
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Set
import random

@register_skill("race_protoss_colossus_navi_A_star_score_area_control_default_protected")
def race_protoss_colossus_navi_A_star_score_area_control_default_protected(obs: str):
    """
    Colossus control script optimized for maximum area damage impact.

    Key strategies:
    1. Protected positioning behind allied units
    2. Optimal line attack positioning for maximum splash damage
    3. Shield-based risk assessment and positioning
    4. Area denial through threat projection
    5. Coordinated movement with escort units

    The script emphasizes the Colossus's role as a critical area control unit
    while maintaining protection due to its high strategic value.

    Decision order (as implemented below):
    - If action 0 is available, return 0 immediately.
    - Enemies and allies visible: path to a "safe point" anchored on the
      allies when needed; otherwise pick the best-scored enemy (threat score,
      focus fire, closeness to the safe point) and attack it or path to it.
    - Only enemies visible: kite a close melee enemy, otherwise attack or
      approach the highest-scored enemy.
    - Only allies visible: regroup toward the (melee) ally centroid.
    - Anything not handled falls back to ``default_action(obs)``.

    Action ids >= 6 are treated as attacks on target id ``action - 6``.
    NOTE(review): positions/distances in ``obs_data`` appear to be relative to
    this unit and normalized by its sight range (they are compared against
    ``own_shoot_range / own_sight_range``) -- confirm against ``parse_obs``.

    Args:
        obs (str): Observation string containing game state

    Returns:
        The chosen action: 0, an action id taken from the available actions,
        or the result of ``attack``, ``find_path`` or ``default_action``.
    """
    import math

    # Unit types treated as melee by the positioning / kiting logic.
    melee_types = ('zealot', 'zergling', 'baneling')

    # Parse observation
    obs_data = parse_obs(obs)
    valid_actions = obs_data.available_actions

    # Action 0 short-circuits every other decision.
    if 0 in valid_actions:
        return 0

    def score_target(unit):
        """Score ``unit`` as a target; higher is better, -1 for dead units.

        Units exposing ``can_attack`` (enemies) are scored on type priority,
        distance, splash and line-attack potential, shields, focus fire and
        remaining health. Any other unit gets a simplified positional score.
        """
        if unit.health <= 0:
            return -1

        # Refined unit type priorities with splash damage consideration
        unit_priorities = {
            'colossus': 35.0,  # Further increased priority
            'stalker': 30.0,   # Enhanced anti-armor focus
            'zealot': 45.0,    # Higher melee threat recognition
        }

        base_priority = unit_priorities.get(unit.unit_type.lower(), 5.0)

        def calculate_splash_potential(target_pos, radius=0.25):
            """Sum value-, distance- and health-weighted enemies within ``radius`` of ``target_pos``."""
            splash_score = 0
            center_x, center_y = target_pos

            for enemy in obs_data.enemies:
                if enemy.id != unit.id:  # Don't count target unit itself
                    dist = ((enemy.position[0] - center_x)**2 +
                        (enemy.position[1] - center_y)**2)**0.5
                    if dist <= radius:
                        # Weight by unit value and distance from center
                        unit_value = unit_priorities.get(enemy.unit_type.lower(), 5.0)
                        distance_weight = 1 - (dist/radius)**0.5  # Non-linear falloff
                        contribution = unit_value * distance_weight * enemy.health

                        # Bonus for clustered light units. Applied to this
                        # enemy's own contribution only: scaling the running
                        # total made the score depend on enemy ordering.
                        if enemy.unit_type.lower() in ['marine', 'zergling', 'zealot']:
                            contribution *= 1.2
                        splash_score += contribution

            return splash_score

        if hasattr(unit, 'can_attack'):  # Enemy unit
            score = base_priority

            # Bonus inside 80% of shoot range (normalized by sight range),
            # linear falloff floored at 0.5 beyond it.
            optimal_range = obs_data.own_shoot_range * 0.8 / obs_data.own_sight_range
            if unit.distance <= optimal_range:
                distance_factor = 1.2  # Bonus for targets at optimal range
            else:
                distance_factor = max((1.5 - unit.distance), 0.5)
            score *= distance_factor

            # Calculate splash damage potential
            splash_bonus = calculate_splash_potential(unit.position)
            score += splash_bonus * 0.8  # Weight splash damage contribution

            def calculate_line_attack_value():
                """Score other enemies lying near the line from the origin through the target.

                The origin is presumably this unit, if positions are
                relative -- NOTE(review): confirm.
                """
                line_score = 0
                if unit.distance == 0:
                    # Direction is undefined (and would divide by zero).
                    return line_score
                unit_vector = (unit.position[0]/unit.distance, unit.position[1]/unit.distance)

                for enemy in obs_data.enemies:
                    if enemy.id != unit.id:
                        # Project enemy position onto attack line
                        dot_product = (enemy.position[0] * unit_vector[0] +
                                    enemy.position[1] * unit_vector[1])
                        projection = (dot_product * unit_vector[0], dot_product * unit_vector[1])

                        # Calculate perpendicular distance to line
                        perp_dist = ((enemy.position[0] - projection[0])**2 +
                                (enemy.position[1] - projection[1])**2)**0.5

                        if perp_dist < 0.2:  # Within thermal lance width
                            line_score += unit_priorities.get(enemy.unit_type.lower(), 5.0) * (1 - perp_dist/0.2)

                return line_score

            line_attack_value = calculate_line_attack_value()
            score += line_attack_value * 0.6  # Weight line attack potential

            # Shield consideration
            if unit.shield is not None and unit.shield > 0:
                shield_factor = 1.2 if unit.shield < 0.3 else 1.0  # Prioritize low shield targets
                score *= shield_factor

            # Allies whose last action was an attack on this unit.
            num_attackers = sum(1 for ally in obs_data.allies
                            if ally.last_action >= 6 and ally.last_action - 6 == unit.id)

            # Commit to the target of our own previous attack action.
            if unit.id == obs_data.last_action - 6:
                persistence_bonus = 2.2  # Enhanced target commitment
                score *= persistence_bonus

            if num_attackers > 0:
                focus_bonus = min(1.3 ** num_attackers, 2.5)  # Capped focus fire bonus
                score *= focus_bonus

            # Health factor with shield consideration
            if unit.shield is not None:
                total_health = unit.health + unit.shield
                # NOTE(review): if health and shield are each normalized to
                # [0, 1], a full-health, full-shield target gets factor 0.
                health_factor = (2 - total_health)  # More aggressive scaling
            else:
                health_factor = (2 - unit.health)
            score *= health_factor

        else:  # Ally unit - simplified scoring for positioning reference
            score = base_priority * 0.5  # Reduced priority for ally positioning

            health_factor = (1.5 - unit.health) if unit.health < 0.5 else 1.0
            score *= health_factor

            distance_factor = max((1.2 - unit.distance), 0.6)
            score *= distance_factor

        return score

    def control_logic():
        """Execute area control and line attack tactics"""
        attack_actions = [a for a in valid_actions if a >= 6]

        def attack_fallback():
            """Repeat the last attack if still valid, else hit the closest attackable enemy, else any attack.

            Only called when ``attack_actions`` is non-empty.
            """
            if obs_data.last_action in attack_actions:
                return obs_data.last_action
            closest_attackable = min(
                (enemy for enemy in obs_data.enemies if enemy.can_attack),
                key=lambda enemy: enemy.distance,
                default=None,
            )
            if closest_attackable is not None:
                return attack(closest_attackable.id)
            return random.choice(attack_actions)

        # If there are enemies
        if obs_data.enemies:
            # Enemy centroid: enemies with normalized distance < 1 if any,
            # otherwise all visible enemies.
            enemy_in_range = [enemy for enemy in obs_data.enemies if enemy.distance < 1]
            if len(enemy_in_range) > 0:
                enemy_center = (sum(e.position[0] for e in enemy_in_range) / len(enemy_in_range),
                            sum(e.position[1] for e in enemy_in_range) / len(enemy_in_range))
            else:
                enemy_center = (sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies),
                            sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies))
            if obs_data.allies:
                melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in melee_types]
                melee_enemy = [enemy for enemy in enemy_in_range if enemy.unit_type.lower() in melee_types]
                if melee_ally:
                    # Anchor on the centroid of melee allies (the escort).
                    ally_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally)
                    ally_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally)
                else:
                    # Anchor on half of the ally-centroid offset (the extra / 2).
                    ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies) / 2
                    ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies) / 2

                # Safe point starts at the anchor; with a melee escort it is
                # pushed 2 / own_sight_range further from the enemy centroid
                # on each axis.
                dx = ally_x - enemy_center[0]
                dy = ally_y - enemy_center[1]
                safe_x = ally_x
                safe_y = ally_y
                if melee_ally:
                    safe_x = safe_x + (dx / abs(dx)) * 2/obs_data.own_sight_range if dx != 0 else safe_x
                    safe_y = safe_y + (dy / abs(dy)) * 2/obs_data.own_sight_range if dy != 0 else safe_y
                # A melee enemy within 4 / own_sight_range pushes the safe
                # point a further 1 / own_sight_range away from it per axis.
                melee_threaten = False
                if melee_enemy:
                    closest_melee_enemy = min(melee_enemy, key=lambda x: x.distance)
                    if closest_melee_enemy.distance <= 4/obs_data.own_sight_range:
                        melee_threaten = True
                        dx = safe_x - closest_melee_enemy.position[0]
                        dy = safe_y - closest_melee_enemy.position[1]
                        safe_x = safe_x + (dx / abs(dx)) * 1/obs_data.own_sight_range if dx != 0 else safe_x
                        safe_y = safe_y + (dy / abs(dy)) * 1/obs_data.own_sight_range if dy != 0 else safe_y

                distance = (safe_x ** 2 + safe_y ** 2) ** 0.5
                criterion = 4/obs_data.own_sight_range
                # Consider repositioning right after an attack, when no attack
                # is available, or when the safe point is far away.
                if obs_data.last_action >= 6 or len(attack_actions) == 0 or distance > 0.9:
                    target_angle = math.atan2(enemy_center[1], enemy_center[0])
                    safe_angle = math.atan2(ally_y, ally_x)
                    angle_diff = abs(target_angle - safe_angle)
                    # Move unless already near the safe point; require that the
                    # enemy and anchor directions differ by more than 20
                    # degrees (pi/9), a melee enemy is close, or the safe point
                    # is far.
                    if distance > criterion and ((math.pi/9 < angle_diff < 17*math.pi/9) or melee_threaten or distance > 0.9):
                        path_action = find_path(obs_data, safe_x, safe_y)
                        if path_action:
                            return path_action
                # Focus fire logic
                # Count how many allies are attacking each enemy
                target_counts = {}
                for ally in obs_data.allies:
                    if ally.last_action >= 6:
                        target_id = ally.last_action - 6
                        target_counts[target_id] = target_counts.get(target_id, 0) + 1
                # Enemies far from the safe point are penalized
                enemy_safe_distance = {enemy.id: ((enemy.position[0] - safe_x) ** 2 + (enemy.position[1] - safe_y) ** 2) ** 0.5 for enemy in obs_data.enemies}
                # Find best target combining focus fire and threat scoring
                target_scores = {enemy.id: score_target(enemy) for enemy in obs_data.enemies}
                for target_id, count in target_counts.items():
                    if target_id in target_scores:
                        target_scores[target_id] += count * 0.5
                for target_id, scores in target_scores.items():
                    target_scores[target_id] = scores * (1 - enemy_safe_distance[target_id] * 0.3)
                best_target_id = max(target_scores.items(), key=lambda x: x[1])[0]
                best_target = next(enemy for enemy in obs_data.enemies if enemy.id == best_target_id)
                if best_target.can_attack:
                    return attack(best_target_id)
                else:
                    # Best target is not in shoot range, move to target
                    dx = best_target.position[0]
                    dy = best_target.position[1]

                    # Only approach when the target direction is within 20
                    # degrees (pi/9) of the anchor direction, or when there is
                    # no melee escort to stay with.
                    target_angle = math.atan2(dy, dx)
                    safe_angle = math.atan2(ally_y, ally_x)
                    angle_diff = abs(target_angle - safe_angle)
                    if angle_diff < math.pi/9 or angle_diff > 17*math.pi/9 or not melee_ally:
                        if best_target.distance > obs_data.own_shoot_range / obs_data.own_sight_range:
                            path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                            if path_action:
                                return path_action
                    if attack_actions:
                        return attack_fallback()
                    else:
                        if distance > criterion:
                            path_action = find_path(obs_data, safe_x, safe_y)
                            if path_action:
                                return path_action

            # If there are only enemies
            else:
                # Closest enemy as target (enemies is non-empty here)
                closest_enemy = min(obs_data.enemies, key=lambda enemy: enemy.distance)
                # No allies: kite melee enemies
                if closest_enemy.unit_type.lower() in melee_types:
                    if closest_enemy.distance <= 4 / obs_data.own_sight_range and obs_data.last_action >= 6:
                        # Flee in the direction opposite the enemy centroid.
                        enemy_x = sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies)
                        enemy_y = sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies)
                        target_x = - enemy_x
                        target_y = - enemy_y

                        # If the flee point falls outside the [0, 32] box, head
                        # toward (0.5, 0.5) * 32 instead. NOTE(review): assumes
                        # own_position is normalized by a 32-unit map -- confirm.
                        g_x = target_x * obs_data.own_sight_range + (obs_data.own_position[0] * 32)
                        g_y = target_y * obs_data.own_sight_range + (obs_data.own_position[1] * 32)
                        if not (0 <= g_x <= 32 and 0 <= g_y <= 32):
                            target_x = (0.5 - obs_data.own_position[0]) * 32 / obs_data.own_sight_range
                            target_y = (0.5 - obs_data.own_position[1]) * 32 / obs_data.own_sight_range

                        path_action = find_path(obs_data, target_x, target_y)
                        if path_action:
                            return path_action
                    if closest_enemy.can_attack:
                        return attack(closest_enemy.id)
                else:
                    # Closest enemy is not melee: highest-scored enemy as target
                    target_scores = {enemy.id: score_target(enemy) for enemy in obs_data.enemies}
                    # Among targets tied on the top score, pick the closest.
                    max_score = max(target_scores.values())
                    max_score_target_ids = [target_id for target_id, score in target_scores.items() if score == max_score]
                    max_score_targets = [enemy for enemy in obs_data.enemies if enemy.id in max_score_target_ids]
                    best_target = min(max_score_targets, key=lambda x: x.distance)
                    if best_target.can_attack:
                        return attack(best_target.id)
                    else:
                        # Best target is not in shoot range, move to target
                        dx = best_target.position[0]
                        dy = best_target.position[1]
                        if best_target.distance > obs_data.own_shoot_range / obs_data.own_sight_range:
                            path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                            if path_action:
                                return path_action
                            elif attack_actions:
                                return attack_fallback()
        # If there are no enemies
        else:
            if obs_data.allies:
                # Regroup toward the melee-ally centroid, or the centroid of
                # all allies when there is no melee escort.
                melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in melee_types]
                group = melee_ally or obs_data.allies
                dx = sum(ally.position[0] for ally in group) / len(group)
                dy = sum(ally.position[1] for ally in group) / len(group)
                distance = (dx ** 2 + dy ** 2) ** 0.5
                if distance > 0.05:
                    path_action = find_path(obs_data, dx, dy)
                    if path_action:
                        return path_action
        return default_action(obs)
    return control_logic()

# Public skills exported by this module; the __main__ smoke test iterates it.
__all__ = ["race_protoss_colossus_navi_A_star_score_area_control_default_protected"]

if __name__ == "__main__":
    # Smoke test: run every exported skill against each shared test
    # parameter set (passed as keyword arguments) and print the result.
    from harl.test import test_params
    for skill in __all__:
        print(f"Testing {skill}...")
        for test_id, test_param in enumerate(test_params):
            print(f"Test param: {test_id}")
            print(f"Result: {globals()[skill](**test_param)}")