from harl.common.skills.smacv2.atomic_actions.combat import attack
from harl.common.skills.smacv2.atomic_actions.move import move_north, move_south, move_east, move_west
from harl.common.skills.smacv2.atomic_actions.basic import stop
from harl.common.skills.smacv2.atomic_actions.heal import heal
from harl.utils.skill_utils import parse_obs
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Set
import random

# Pathfinder for target position with unit positions as obstacles (Generate New)
def find_path(obs_data, target_x: float, target_y: float, target_type='point', timeout=1e5):
    """Return the first move of an A* path from this unit toward a target.

    Coordinate convention (see the transform below): ``obs_data.own_position``
    is the unit's global position divided by the map size (assumed 32x32),
    while ``target_x``/``target_y`` and every ally/enemy ``position`` are
    offsets relative to this unit divided by ``own_sight_range``. All of them
    are converted to global map coordinates. The search then runs on a
    4-connected grid of unit steps anchored at the unit's current position.

    Other units are circular obstacles. Enemies always block. Allies block
    only when their last action was a targeted action (>= 6); every visible
    ally still contributes to the flocking costs. Each step costs 1.0 plus:
    - a turning penalty;
    - a threat penalty near enemies with a positive ``threat_radius`` entry;
    - alignment and cohesion terms relative to visible allies.

    Args:
        obs_data: Parsed observation (from ``parse_obs``). Reads
            ``available_actions``, ``own_position``, ``own_sight_range``,
            ``own_unit_type``, ``own_shoot_range``, ``last_action``,
            ``allies`` and ``enemies``.
        target_x, target_y: Target offset relative to this unit, in units of
            ``own_sight_range``.
        target_type: ``'point'`` for a free position, otherwise the lower-case
            unit type of the unit standing at the target. For a unit target,
            a unit located exactly at the target is not treated as an
            obstacle. The search also stops once within ``own_shoot_range``
            of the target instead of within 2 map units.
        timeout: Maximum number of node expansions.

    Returns:
        ``move_north()``/``move_south()``/``move_east()``/``move_west()`` for
        the first step of the found path, or ``None`` when:
        - the unit is already close enough;
        - the target is off-map or fully blocked;
        - no path was found within ``timeout`` expansions;
        - the required move is not in ``available_actions``.
    """
    import heapq
    import math
    from dataclasses import dataclass

    valid_actions = obs_data.available_actions

    @dataclass
    class Node:
        x: int
        y: int
        direction: tuple = None  # (dx, dy) of the step that reached this node
        g_cost: float = float('inf')  # Cost from start to current
        h_cost: float = float('inf')  # Estimated cost to goal
        parent: 'Node' = None

        @property
        def f_cost(self):
            return self.g_cost + self.h_cost

        def __lt__(self, other):
            # heapq only needs "<": order the open set by f = g + h.
            return self.f_cost < other.f_cost

    def manhattan_distance(x1, y1, x2, y2):
        """Heuristic used for h_cost."""
        return abs(x1 - x2) + abs(y1 - y2)

    def euclidean_distance(x1, y1, x2, y2):
        return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5

    # Own position is own_global_position / map size (32x32); target position
    # is (target_global_position - own_global_position) / own_sight_range.
    # Transform start and target to global map coordinates.
    start_x = obs_data.own_position[0] * 32
    start_y = obs_data.own_position[1] * 32
    target_x = target_x * obs_data.own_sight_range + start_x
    target_y = target_y * obs_data.own_sight_range + start_y
    start_distance = manhattan_distance(start_x, start_y, target_x, target_y)

    # StarCraft 2 unit collision radius
    unit_radius = {
        'marine': 0.275,
        'hydralisk': 0.525,
        'stalker': 0.525,
        'marauder': 0.4625,
        'baneling': 0.275,  # Extremely dangerous splash damage.
        'zealot': 0.425,
        'zergling': 0.275,
        'medivac': 0.65,   # High-priority support unit.
        'colossus': 0.9,  # Devastating long-range AOE damage.
        }

    own_radius = unit_radius.get(obs_data.own_unit_type.lower(), 0.5)
    target_radius = unit_radius.get(target_type, 0)

    # Radius (map units) around an enemy inside which path steps pay a threat cost
    threat_radius = {
        'colossus': 9,  # Long range threat
        'hydralisk': 7,
        'stalker': 8,
        'marauder': 8,
        'baneling': 4,  # Explosive threat
        'zealot': 4,
        'zergling': 4,
        'marine': 7,
        'medivac': 0,
    }

    unit_positions = set()  # (x, y, radius) circles the path must not enter
    threat_zones = set()    # (x, y, threat radius, unit type) of threatening enemies
    ally_positions = set()  # (x, y, radius) of visible allies, used for flocking
    for ally in obs_data.allies:
        x = ally.position[0] * obs_data.own_sight_range + start_x
        y = ally.position[1] * obs_data.own_sight_range + start_y
        # A unit standing exactly on a unit target is the target itself.
        if target_type != 'point' and x == target_x and y == target_y:
            continue
        ally_radius = unit_radius.get(ally.unit_type.lower(), 0.5)
        # Only allies whose last action was targeted (>= 6) block the path.
        if ally.last_action >= 6:
            unit_positions.add((x, y, ally_radius))
        ally_positions.add((x, y, ally_radius))
    for enemy in obs_data.enemies:
        x = enemy.position[0] * obs_data.own_sight_range + start_x
        y = enemy.position[1] * obs_data.own_sight_range + start_y
        if target_type != 'point' and x == target_x and y == target_y:
            continue
        enemy_type = enemy.unit_type.lower()
        unit_positions.add((x, y, unit_radius.get(enemy_type, 0.5)))
        enemy_threat = threat_radius.get(enemy_type, 0)
        if enemy_threat > 0:
            threat_zones.add((x, y, enemy_threat, enemy_type))

    def is_target_accessible(target_x, target_y, unit_positions):
        """Return True if some on-map approach point around the target is not inside an obstacle."""
        move_amount = own_radius + target_radius
        directions = [(0, move_amount), (move_amount, 0), (0, -move_amount), (-move_amount, 0),
                      (move_amount, move_amount), (-move_amount, move_amount),
                      (move_amount, -move_amount), (-move_amount, -move_amount)]
        if target_type == 'point':
            # A free point may itself be the approach point.
            directions += [(0, 0)]
        for dx, dy in directions:
            check_x = target_x + dx
            check_y = target_y + dy
            if 0 <= check_x < 32 and 0 <= check_y < 32:
                if not any(euclidean_distance(check_x, check_y, x, y) < radius + own_radius
                           for x, y, radius in unit_positions):
                    return True
        return False

    def calculate_threat_cost(x, y, threat_zones):
        """Threat cost of a position: 2.5 * sum of exp(-dist / radius) over threats covering it."""
        total_threat = 0
        for tx, ty, tradius, _ttype in threat_zones:
            dist = euclidean_distance(x, y, tx, ty)
            if dist <= tradius:
                # Higher threat for closer positions
                total_threat += math.exp(-dist / (tradius))
        return total_threat * 2.5  # Threat weight

    def get_movement_direction(last_action):
        """Convert a move action id (2-5) to a direction vector; other actions give (0, 0)."""
        direction_map = {
            2: (0, 1),   # North
            3: (0, -1),  # South
            4: (1, 0),   # East
            5: (-1, 0),  # West
        }
        return direction_map.get(last_action, (0, 0))

    def get_direction_cost(current_dir, new_dir, is_first_move=False):
        """Turning penalty: 0 straight, 0.5 for a 90-degree turn, 1 (1.5 on the first move) for a reversal."""
        if not current_dir:
            return 0
        if current_dir == new_dir:
            return 0
        if current_dir[0] == -new_dir[0] and current_dir[1] == -new_dir[1]:
            return 1.5 if is_first_move else 1
        return 0.5

    def calculate_flocking_costs(x, y, direction):
        """Return (alignment, cohesion) penalties for stepping onto (x, y) in `direction`."""
        # Alignment: L1 mismatch between the step and the unit vectors toward allies within 5 units.
        alignment = 0
        for ax, ay, _ in ally_positions:
            if euclidean_distance(x, y, ax, ay) < 5:  # Alignment range
                to_ally = (ax - x, ay - y)
                magnitude = (to_ally[0]**2 + to_ally[1]**2)**0.5
                if magnitude > 0 and direction:
                    to_ally = (to_ally[0]/magnitude, to_ally[1]/magnitude)
                    alignment += abs(direction[0] - to_ally[0]) + abs(direction[1] - to_ally[1])

        # Cohesion: distance to the allies' centroid, weighted by exp(-4 * progress).
        # The weight is largest near the start and decays as the step gets closer to the target.
        avg_x = sum(ax for ax, _, _ in ally_positions) / len(ally_positions) if ally_positions else x
        avg_y = sum(ay for _, ay, _ in ally_positions) / len(ally_positions) if ally_positions else y
        cohesion = euclidean_distance(x, y, avg_x, avg_y)
        gamma = 4
        if start_distance > 0:
            progress = 1 - euclidean_distance(x, y, target_x, target_y) / start_distance
        else:
            progress = 1.0  # Start is the target; guard against division by zero
        cohesion *= math.exp(-gamma * progress)

        return alignment, cohesion

    def get_neighbors(current: Node, unit_positions, steps):
        """Return costed, unblocked, on-map 4-neighbours of `current`."""
        is_first_move = steps == 0
        directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
        if current.direction:
            # Push moves that keep the current heading first (only affects heap tie-breaking).
            directions.sort(key=lambda d: get_direction_cost(current.direction, d, is_first_move))

        neighbors = []
        for dx, dy in directions:
            new_x, new_y = current.x + dx, current.y + dy
            # Stay inside the 32x32 map
            if not (0 <= new_x < 32 and 0 <= new_y < 32):
                continue
            # Positions are floats: block any step that overlaps a unit's collision circle.
            if any(euclidean_distance(new_x, new_y, x, y) < (own_radius + radius)
                   for x, y, radius in unit_positions):
                continue
            neighbor = Node(new_x, new_y, direction=(dx, dy))
            direction_cost = get_direction_cost(current.direction, (dx, dy), is_first_move)
            threat_cost = calculate_threat_cost(new_x, new_y, threat_zones)
            alignment, cohesion = calculate_flocking_costs(new_x, new_y, (dx, dy))
            base_cost = 1.0
            neighbor.g_cost = current.g_cost + base_cost + direction_cost + threat_cost + alignment + cohesion
            neighbors.append(neighbor)
        return neighbors

    # Check target accessibility before pathfinding
    if not is_target_accessible(target_x, target_y, unit_positions):
        return None  # Target is completely surrounded/blocked

    # Initialize A*; the start node carries the heading of the unit's last move
    start = Node(start_x, start_y, direction=get_movement_direction(obs_data.last_action))
    start.g_cost = 0
    start.h_cost = start_distance

    open_set = [start]
    closed_set = set()
    # Arrival radius: 2 map units for a point, shooting range for a unit target.
    criteria = 2 if target_type == 'point' else obs_data.own_shoot_range

    steps = 0
    while open_set and steps < timeout:
        current = heapq.heappop(open_set)

        if euclidean_distance(target_x, target_y, current.x, current.y) <= criteria:
            if current is start:
                # Already close enough: no move needed. start.direction is the
                # unit's previous heading, not a path step, so it must not be
                # converted into a move here.
                return None
            # Walk back to the node reached by the first step out of start.
            while current.parent is not None and current.parent is not start:
                current = current.parent
            dx, dy = current.direction
            if abs(dx) > abs(dy):
                if dx > 0 and 4 in valid_actions:
                    return move_east()
                elif dx < 0 and 5 in valid_actions:
                    return move_west()
            else:
                if dy > 0 and 2 in valid_actions:
                    return move_north()
                elif dy < 0 and 3 in valid_actions:
                    return move_south()
            return None

        closed_set.add((current.x, current.y))

        for neighbor in get_neighbors(current, unit_positions, steps):
            if (neighbor.x, neighbor.y) in closed_set:
                continue
            # Duplicates of a position may coexist in the heap with different
            # costs/headings; the cheapest is popped first. (A dataclass-equality
            # "not in open_set" test could never match a fresh node, whose
            # h_cost is still inf, so it is omitted.)
            neighbor.parent = current
            neighbor.h_cost = manhattan_distance(neighbor.x, neighbor.y, target_x, target_y)
            heapq.heappush(open_set, neighbor)
        steps += 1

    return None

def default_tactic(obs: str):
    import math
    # Parse observation
    obs_data = parse_obs(obs)
    # Get set of available actions
    valid_actions = obs_data.available_actions

    if 0 in valid_actions:
        return 0
    
    def score_target(unit):
        """Score a unit as an attack target (enemy) or heal target (ally).

        Units that have a ``can_attack`` attribute are scored as enemies, all
        others as allies. Dead units (``health <= 0``) score -1; otherwise a
        higher score is preferred. Reads ``obs_data`` from the enclosing
        ``default_tactic`` scope.

        Enemy score = type priority x matchup multiplier x distance factor
        (x distance-to-ranged-allies factor when we are melee) x persistence
        and focus-fire bonuses x local combat-advantage factor x
        missing-health factor (2 - health).

        Ally score = type priority x matchup multiplier x missing-health
        factor x distance factor.

        NOTE: callers in this file detect ties with ``==`` / ``>=`` on these
        floats, so the order of the multiplications below matters.
        """
        if unit.health <= 0:
            return -1

        score = 0

        # Base priority per target unit type; unknown types fall back to 5.0 below
        unit_priorities = {
            'colossus': 35.0,  # Further increased priority
            'stalker': 30.0,   # Enhanced anti-armor focus
            'zealot': 45.0,    # Higher melee threat recognition

            'marine': 45.0,    # Balanced damage dealer priority
            'marauder': 35.0,  # Anti-armor specialist
            'medivac': 30.0,   # Support unit priority

            'hydralisk': 30.0,  # High priority for their sustained DPS
            'zergling': 35.0,   # Medium priority as swarm units
            'baneling': 45.0,   # Critical priority due to splash damage
        }

        # Matchup multiplier: outer key is OUR unit type, inner key is the
        # target's unit type; missing pairs default to 1.0 below
        unit_counters = {
            'colossus': {'colossus': 1.2, 'stalker': 1.0, 'zealot': 1.5},
            'stalker': {'colossus': 1.2, 'stalker': 1.0, 'zealot': 1.5},
            'zealot': {'colossus': 1.2, 'stalker': 1.0, 'zealot': 1.5},

            'marine': {'marine': 1.5, 'medivac': 1.0, 'marauder': 1.2},
            'marauder': {'marine': 1.5, 'medivac': 1.0, 'marauder': 1.2},
            'medivac': {'marine': 1.5, 'medivac': 1.0, 'marauder': 1.2},

            'hydralisk': {'hydralisk': 1.0, 'zergling': 1.2, 'baneling': 1.5},
            'zergling': {'hydralisk': 1.2, 'zergling': 1.5, 'baneling': 1.0},
            'baneling': {'hydralisk': 1.2, 'zergling': 1.5, 'baneling': 1.0},

        }

        base_priority = unit_priorities.get(unit.unit_type.lower(), 5.0)
        # NOTE(review): is_ranged is computed but never used below
        is_ranged = unit.unit_type.lower() not in ['zealot', 'zergling', 'baneling']
        own_is_ranged = obs_data.own_unit_type.lower() not in ['zealot', 'zergling', 'baneling']

        # Scale the type priority by the (our type, target type) matchup multiplier
        matchup_mult = unit_counters.get(obs_data.own_unit_type.lower(), {}).get(unit.unit_type.lower(), 1.0)
        base_priority *= matchup_mult

        if hasattr(unit, 'can_attack'):  # Enemy unit
            score = base_priority

            # max(2 - distance, 0.5): nearer targets score higher, floored at 0.5x
            distance_factor = max((1 - unit.distance) + 1, 0.5)
            score *= distance_factor

            # Melee attackers also prefer enemies close to the centroid of
            # ranged allies (same max(2 - d, 0.5) shape)
            if not own_is_ranged:
                range_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() not in ['zealot', 'zergling', 'baneling']]
                if range_ally:
                    ally_x = sum(ally.position[0] for ally in range_ally) / len(range_ally)
                    ally_y = sum(ally.position[1] for ally in range_ally) / len(range_ally)
                    ally_distance = ((ally_x - unit.position[0])**2 + (ally_y - unit.position[1])**2)**0.5
                    distance_factor = max((1 - ally_distance) + 1, 0.5)
                    score *= distance_factor

            # Local force comparison is centred on the scored enemy's position
            position_x, position_y = unit.position

            def calculate_combat_power(units, radius=0.5):  # Further reduced for tighter control
                # Estimate the strength of `units` within `radius` of the scored enemy.
                # Returns (power, melee_count, ranged_count); the counts only cover
                # units inside the radius. For each unit in range, power adds:
                # type priority x type bonus (ranged 1.3x, 2nd+ melee 1.4x)
                # x health tier (1.5 / 1.0 / 0.6) x proximity (1.3 - dist/radius).
                # The total is then scaled by (1 + cohesion).
                total_power = 0
                ranged_count = 0
                melee_count = 0
                unit_positions = []

                for u in units:
                    dist = ((u.position[0] - position_x)**2 + 
                        (u.position[1] - position_y)**2)**0.5
                    # NOTE(review): every unit's position is collected, not only
                    # those within `radius`, so cohesion reflects the whole group
                    unit_positions.append(u.position)

                    if dist <= radius:
                        base_power = unit_priorities.get(u.unit_type.lower(), 5.0)

                        # Unit type specific power calculation
                        if u.unit_type.lower() in ['zealot', 'zergling', 'baneling']:
                            melee_count += 1
                            if melee_count >= 2:
                                base_power *= 1.4
                        else:
                            ranged_count += 1
                            base_power *= 1.3

                        # Health-based power scaling
                        health_factor = 1.5 if u.health > 0.7 else 1.0 if u.health > 0.4 else 0.6
                        position_factor = 1.3 - (dist/radius)

                        total_power += base_power * health_factor * position_factor

                # Cohesion in (0, 2]: 2 / (1 + mean distance to centroid / 0.3),
                # only computed for more than two units
                cohesion = 0
                if len(unit_positions) > 2:
                    center_x = sum(p[0] for p in unit_positions) / len(unit_positions)
                    center_y = sum(p[1] for p in unit_positions) / len(unit_positions)
                    avg_dist = sum(((p[0] - center_x)**2 + (p[1] - center_y)**2)**0.5 
                                for p in unit_positions) / len(unit_positions)
                    max_desired_dist = 0.3  # Tighter formation control
                    cohesion = 2.0 / (1.0 + (avg_dist / max_desired_dist))

                return total_power * (1 + cohesion), melee_count, ranged_count

            ally_power, ally_swarms, ally_ranged = calculate_combat_power(obs_data.allies)
            enemy_power, enemy_swarms, enemy_ranged = calculate_combat_power(obs_data.enemies)

            # Count allies whose last action was a targeted action aimed at this id.
            # NOTE(review): for a Medivac ally, last_action - 6 is a heal target,
            # not an enemy id, so it may be miscounted here.
            num_attackers = sum(1 for ally in obs_data.allies 
                            if ally.last_action >= 6 and ally.last_action - 6 == unit.id)

            # Stick with the target we attacked last step
            if unit.id == obs_data.last_action - 6:
                persistence_bonus = 2.0  # Stronger target commitment
                score *= persistence_bonus

            if num_attackers > 0:
                focus_bonus = 1.2 ** num_attackers  # Enhanced focus fire emphasis
                # Zerglings/banelings back off a target already hit by >= 3 allies
                # (unless it is their own current target) to avoid overcommitting
                if num_attackers >= 3 and unit.id != obs_data.last_action - 6 and obs_data.own_unit_type.lower() in ['zergling', 'baneling']:
                    focus_bonus = 0.5
                score *= focus_bonus

            # Combat advantage: bonus when local allied power exceeds enemy power by 30%
            advantage_factor = 1.0
            if ally_power > enemy_power * 1.3:
                advantage_factor = 1.2  # More aggressive advantage pursuit
                if ally_swarms >= 3:
                    advantage_factor *= 1.2

            # Double the score when exactly one enemy is within radius of the
            # target's position; otherwise reward local numerical superiority
            if (enemy_swarms + enemy_ranged) == 1:
                advantage_factor *= 2.0
            elif (ally_swarms + ally_ranged) > (enemy_swarms + enemy_ranged):
                advantage_factor *= 1.2

            score *= advantage_factor

            # Health factor: 2 - health, so damaged enemies score higher
            health_factor = (1 - unit.health) + 1
            score *= health_factor

        else:  # Ally unit
            score = base_priority

            # Health factor: 2 - health, so damaged allies are healed first
            health_factor = (1 - unit.health) + 1
            score *= health_factor

            # max(2 - distance, 0.5): prefer nearby allies
            distance_factor = max((1 - unit.distance) + 1, 0.5)
            score *= distance_factor

        return score
    
    def default_action():
        """Fallback action when no tactic-specific rule applied.

        1. Imitate allies. Walk allies from highest to lowest ``score_target``
           score and return the first ally last action that is a move or a
           targeted action (>= 2) and is available to us. Targeted actions
           (>= 6) are only copied from allies of the same role (Medivac vs.
           non-Medivac). For a Medivac, ``action - 6`` is a heal target
           (``heal(ally.id)`` in ``control_logic``); for other units it is an
           attack target (``attack(enemy.id)``). Copying across roles would
           therefore aim the action at an unrelated unit.
        2. Otherwise keep moving in our last direction if that move is still
           available, else take a random available move (2-5).
        3. Stop if no move is available.

        Returns the raw action id in case 1, otherwise the result of the
        matching atomic action helper.
        """
        own_is_healer = obs_data.own_unit_type.lower() == 'medivac'

        target_scores = {ally.id: score_target(ally) for ally in obs_data.allies}
        ally_lookup = {ally.id: ally for ally in obs_data.allies}
        # Highest-scoring allies first (stable for equal scores)
        sorted_target_scores = sorted(target_scores.items(), key=lambda x: x[1], reverse=True)

        for ally_id, _score in sorted_target_scores:
            ally = ally_lookup[ally_id]
            action = ally.last_action
            if action < 2 or action not in valid_actions:
                continue
            ally_is_healer = ally.unit_type.lower() == 'medivac'
            if action >= 6 and ally_is_healer != own_is_healer:
                # Same index, different target list: do not copy across roles.
                continue
            return action

        # Only consider movement actions (2-5) to avoid inefficient random actions
        valid_movement_actions = [a for a in valid_actions if 2 <= a <= 5]

        if valid_movement_actions:
            # Prefer continuing the previous direction when it is still possible
            last_action = obs_data.last_action
            if last_action and 2 <= last_action <= 5 and last_action in valid_movement_actions:
                action = last_action
            else:
                action = random.choice(valid_movement_actions)
        else:
            action = 1  # Stop if no valid moves

        # Convert the action id (always 1-5 here) to its atomic action call
        atomic_actions = {
            1: stop,
            2: move_north,
            3: move_south,
            4: move_east,
            5: move_west,
        }
        return atomic_actions[action]()

    def control_logic():
        # Medivac units control logic
        if obs_data.own_unit_type.lower() == 'medivac':
            attack_actions = [a for a in valid_actions if a >= 6]
            # If there are allies
            if obs_data.allies:
                lowest_health_ally = min(obs_data.allies, key=lambda x: x.health)
                # If there are both allies and enemies
                if obs_data.enemies:
                    enemy_in_range = [enemy for enemy in obs_data.enemies if enemy.distance < 1]
                    # Check if any melee ally, if so and last action is not attack, move to center of melee allies
                    melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in ['zealot', 'zergling', 'baneling']]
                    if melee_ally:
                        # Move to center of melee allies
                        ally_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally)
                        ally_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally)
                    else:
                        ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies)
                        ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies)
                    # Calculate retreat position
                    if len(enemy_in_range) > 0:
                        enemy_center = (sum(e.position[0] for e in enemy_in_range) / len(enemy_in_range),
                                    sum(e.position[1] for e in enemy_in_range) / len(enemy_in_range))
                    else:
                        enemy_center = (sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies),
                                    sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies))
                    
                    dx = ally_x - enemy_center[0]
                    dy = ally_y - enemy_center[1]

                    distance = (dx ** 2 + dy ** 2) ** 0.5

                    safe_x = ally_x + (dx / abs(dx)) * 2/obs_data.own_sight_range if dx != 0 else ally_x
                    safe_y = ally_y + (dy / abs(dy)) * 2/obs_data.own_sight_range if dy != 0 else ally_y

                    distance = (safe_x ** 2 + safe_y ** 2) ** 0.5
                    criterion = 5/obs_data.own_sight_range
                    if obs_data.last_action >= 6 or len(attack_actions) == 0:
                        target_angle = math.atan2(enemy_center[1], enemy_center[0])
                        safe_angle = math.atan2(ally_y, ally_x)
                        angle_diff = abs(target_angle - safe_angle)
                        if distance > criterion and (math.pi/9 < angle_diff < 17*math.pi/9):
                            path_action = find_path(obs_data, safe_x, safe_y)
                            if path_action:
                                return path_action
                target_scores = {ally.id: score_target(ally) for ally in obs_data.allies}
                # Check if there are same max score targets
                max_score = max(target_scores.values())
                max_score_target_ids = [target_id for target_id, score in target_scores.items() if score == max_score]
                max_score_targets = [ally for ally in obs_data.allies if ally.id in max_score_target_ids]
                closest_ally = min(obs_data.allies, key=lambda x: x.distance)
                if len(max_score_targets) > 1:
                    # Chose the closest target
                    best_target = min(max_score_targets, key=lambda x: x.distance)
                else:
                    best_target = max_score_targets[0]
                if (best_target.id + 6) in valid_actions and 0 < best_target.health < 0.9:
                    return heal(best_target.id)
                elif (closest_ally.id + 6) in valid_actions and 0 < closest_ally.health < 0.9:
                    return heal(closest_ally.id)
                elif (lowest_health_ally.id + 6) in valid_actions and 0 < lowest_health_ally.health < 0.9:
                    return heal(lowest_health_ally.id)
                else:
                    # Move to the target
                    dx = best_target.position[0]
                    dy = best_target.position[1]
                    path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                    if path_action:
                        return path_action
            # If there are no allies
            else:
                # If there are only enemies
                if obs_data.enemies:
                    enemy_x = sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies)
                    enemy_y = sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies)
                    target_x = - enemy_x
                    target_y = - enemy_y

                    g_x = target_x * obs_data.own_sight_range + (obs_data.own_position[0] * 32)
                    g_y = target_y * obs_data.own_sight_range + (obs_data.own_position[1] * 32)
                    if not (0 <= g_x <= 32 and 0 <= g_y <= 32):
                        target_x = (0.5 - obs_data.own_position[0]) * 32 / obs_data.own_sight_range
                        target_y = (0.5 - obs_data.own_position[1]) * 32 / obs_data.own_sight_range

                    path_action = find_path(obs_data, target_x, target_y)
                    if path_action:
                        return path_action
        # Melee units control logic
        elif obs_data.own_unit_type.lower() in ['zealot', 'zergling', 'baneling']:
            # If there are enemies
            if obs_data.enemies:
                enemy_in_range = [enemy for enemy in obs_data.enemies if enemy.distance < 1]
                attack_actions = [a for a in valid_actions if a >= 6]

                if len(enemy_in_range) > 0:
                    enemy_center = (sum(e.position[0] for e in enemy_in_range) / len(enemy_in_range),
                                sum(e.position[1] for e in enemy_in_range) / len(enemy_in_range))
                else:
                    enemy_center = (sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies),
                                sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies))
                if obs_data.allies:
                    # Check if any melee ally, if so and last action is not attack, move to center of melee allies
                    melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in ['zealot', 'zergling', 'baneling']]
                    if melee_ally:
                        # Move to center of melee allies
                        ally_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally) / 2
                        ally_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally) / 2

                        safe_x = ally_x 
                        safe_y = ally_y

                        distance = (safe_x ** 2 + safe_y ** 2) ** 0.5
                        criterion = 2/obs_data.own_sight_range
                        if len(attack_actions) == 0 or distance > 0.5:
                            target_angle = math.atan2(enemy_center[1], enemy_center[0])
                            safe_angle = math.atan2(ally_y, ally_x)
                            angle_diff = abs(target_angle - safe_angle)
                            if distance > criterion and (math.pi/9 < angle_diff < 17*math.pi/9 or distance > 0.5):
                                path_action = find_path(obs_data, safe_x, safe_y)
                                if path_action:
                                    return path_action
                            
                # Enhanced cluster detection with dynamic radius
                enemy_clusters = {}
                cluster_centers = {}
                for enemy in obs_data.enemies:
                    nearby_enemies = []
                    center_x, center_y = enemy.position[0], enemy.position[1]
                    
                    # Dynamic cluster radius based on unit type
                    cluster_radius = 0.3 if obs_data.own_unit_type.lower() == 'baneling' else 0.2
                    
                    for other in obs_data.enemies:
                        distance = ((other.position[0] - enemy.position[0])**2 + 
                                (other.position[1] - enemy.position[1])**2)**0.5
                        if distance <= cluster_radius:
                            nearby_enemies.append(other)
                            center_x += other.position[0]
                            center_y += other.position[1]
                    
                    if nearby_enemies:
                        center_x /= len(nearby_enemies)
                        center_y /= len(nearby_enemies)
                        
                    enemy_clusters[enemy.id] = len(nearby_enemies)
                    cluster_centers[enemy.id] = (center_x, center_y)

                # Enhanced target scoring with tactical considerations
                target_scores = {}
                for enemy in obs_data.enemies:
                    base_score = score_target(enemy)
                    
                    # Enhanced cluster bonus for splash damage
                    if obs_data.own_unit_type.lower() == 'baneling':
                        cluster_bonus = 1.5 ** enemy_clusters[enemy.id]
                    else:
                        cluster_bonus = 1.2 ** enemy_clusters[enemy.id]
                    # Calculate final score with all factors
                    target_scores[enemy.id] = (base_score + cluster_bonus)

                # Check if there are same max score targets
                max_score = max(target_scores.values())
                max_score_targets = [enemy for enemy in obs_data.enemies 
                           if target_scores[enemy.id] >= max_score]  # Allow for close scores
                if len(max_score_targets) > 1:
                    # Choose target balancing distance and cluster potential
                    best_target = min(max_score_targets, 
                            key=lambda x: x.distance)
                else:
                    best_target = max_score_targets[0]

                if best_target.can_attack:
                    return attack(best_target.id)
                else:
                    # Move to the target
                    dx = best_target.position[0]
                    dy = best_target.position[1]
                    path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                    if path_action:
                        return path_action
                    elif attack_actions:
                        attackable_enemies = [enemy for enemy in obs_data.enemies if enemy.can_attack]
                        if obs_data.last_action in attack_actions:
                            return obs_data.last_action
                        if attackable_enemies:
                            return attack(min(attackable_enemies, key=lambda e: e.distance).id)
                        return random.choice(attack_actions)

            # If there are no enemies
            else:
                # If there are only allies
                if obs_data.allies:
                    # Improved melee group formation
                    melee_allies = [ally for ally in obs_data.allies 
                                if ally.unit_type.lower() in ['zealot', 'zergling', 'baneling']]
                    if melee_allies:
                        spacing = 0.1 if obs_data.own_unit_type.lower() == 'baneling' else 0.05
                        # Dynamic group positioning
                        center_x = sum(ally.position[0] for ally in melee_allies) / len(melee_allies)
                        center_y = sum(ally.position[1] for ally in melee_allies) / len(melee_allies)
                        
                        # Calculate spread from center
                        max_spread = max(((ally.position[0] - center_x)**2 + 
                                        (ally.position[1] - center_y)**2)**0.5 
                                    for ally in melee_allies)
                        
                        own_distance = ((center_x)**2 + (center_y)**2)**0.5
                        
                        if own_distance > spacing or max_spread > 0.1:
                            # Move toward center while maintaining minimum spacing
                            adjusted_x = center_x * 0.85  # Slight offset to prevent overcrowding
                            adjusted_y = center_y * 0.85
                            path_action = find_path(obs_data, adjusted_x, adjusted_y)
                            if path_action:
                                return path_action
                    else:
                        ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies)
                        ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies)
                        distance = (ally_x ** 2 + ally_y ** 2) ** 0.5
                        if distance > 0.05:
                            dx = ally_x
                            dy = ally_y
                            path_action = find_path(obs_data, dx, dy)
                            if path_action:
                                return path_action
        # Ranged units control logic
        else:
            attack_actions = [a for a in valid_actions if a >= 6]
            # If there are enemies
            if obs_data.enemies:
                # If there are both allies and enemies
                # Calculate retreat position
                enemy_in_range = [enemy for enemy in obs_data.enemies if enemy.distance < 1]
                if len(enemy_in_range) > 0:
                    enemy_center = (sum(e.position[0] for e in enemy_in_range) / len(enemy_in_range),
                                sum(e.position[1] for e in enemy_in_range) / len(enemy_in_range))
                else:
                    enemy_center = (sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies),
                                sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies))
                if obs_data.allies:
                    # Check if any melee ally, if so and last action is not attack, move to center of melee allies
                    melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in ['zealot', 'zergling', 'baneling']]
                    melee_enemy = [enemy for enemy in enemy_in_range if enemy.unit_type.lower() in ['zealot', 'zergling', 'baneling']]
                    if melee_ally:
                        # Move to center of melee allies
                        ally_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally)
                        ally_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally)
                        
                    else:
                        ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies) / 2
                        ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies) / 2

                    dx = ally_x - enemy_center[0]
                    dy = ally_y - enemy_center[1]
                    safe_x = ally_x
                    safe_y = ally_y
                    if melee_ally:
                        safe_x = safe_x + (dx / abs(dx)) * 2/obs_data.own_sight_range if dx != 0 else safe_x
                        safe_y = safe_y + (dy / abs(dy)) * 2/obs_data.own_sight_range if dy != 0 else safe_y
                    melee_threaten = False
                    if melee_enemy:
                        closest_melee_enemy = min(melee_enemy, key=lambda x: x.distance)
                        if closest_melee_enemy.distance <= 4/obs_data.own_sight_range:
                            melee_threaten = True
                            dx = safe_x - closest_melee_enemy.position[0]
                            dy = safe_y - closest_melee_enemy.position[1]
                            safe_x = safe_x + (dx / abs(dx)) * 1/obs_data.own_sight_range if dx != 0 else safe_x
                            safe_y = safe_y + (dy / abs(dy)) * 1/obs_data.own_sight_range if dy != 0 else safe_y
                    
                    distance = (safe_x ** 2 + safe_y ** 2) ** 0.5
                    criterion = 4/obs_data.own_sight_range
                    if obs_data.last_action >= 6 or len(attack_actions) == 0 or distance > 0.9:
                        target_angle = math.atan2(enemy_center[1], enemy_center[0])
                        safe_angle = math.atan2(ally_y, ally_x)
                        angle_diff = abs(target_angle - safe_angle)
                        if distance > criterion and ((math.pi/9 < angle_diff < 17*math.pi/9) or melee_threaten or distance > 0.9):
                            path_action = find_path(obs_data, safe_x, safe_y)
                            if path_action:
                                return path_action
                    # Focus fire logic
                    # Count how many allies are attacking each enemy
                    target_counts = {}
                    for ally in obs_data.allies:
                        if ally.last_action >= 6:
                            target_id = ally.last_action - 6
                            target_counts[target_id] = target_counts.get(target_id, 0) + 1
                    # Distance to safe point of each enemy affect target choosing
                    enemy_safe_distance = {enemy.id: ((enemy.position[0] - safe_x) ** 2 + (enemy.position[1] - safe_y) ** 2) ** 0.5 for enemy in obs_data.enemies}
                    # Find best target combining focus fire and threat scoring
                    target_scores = {enemy.id: score_target(enemy) for enemy in obs_data.enemies}
                    for target_id, count in target_counts.items():
                        if target_id in target_scores:
                            target_scores[target_id] += count * 0.5
                    for target_id, scores in target_scores.items():
                        target_scores[target_id] = scores * (1 - enemy_safe_distance[target_id] * 0.3)
                    best_target_id = max(target_scores.items(), key=lambda x: x[1])[0]
                    best_target = next(enemy for enemy in obs_data.enemies if enemy.id == best_target_id)
                    if best_target.can_attack:
                        return attack(best_target_id)
                    else:
                        # Best target is not in shoot range, move to target
                        dx = best_target.position[0]
                        dy = best_target.position[1]

                        # Only move to target if its direction is not conflicting with the safe point
                        # Check if target direction aligns with safe point direction
                        target_angle = math.atan2(dy, dx)
                        safe_angle = math.atan2(ally_y, ally_x)
                        angle_diff = abs(target_angle - safe_angle)
                        # Only move if angle difference is less than 90 degrees
                        if angle_diff < math.pi/9 or angle_diff > 17*math.pi/9 or not melee_ally:
                            if best_target.distance > obs_data.own_shoot_range / obs_data.own_sight_range:
                                path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                                if path_action:
                                    return path_action
                        if attack_actions:
                            if obs_data.last_action in attack_actions:
                                return obs_data.last_action
                            attackable_enemies = [enemy for enemy in obs_data.enemies if enemy.can_attack]
                            closest_enemy = min(
                                [enemy for enemy in attackable_enemies],
                                key=lambda enemy: enemy.distance,
                                default=None,
                            )
                            if closest_enemy and closest_enemy.can_attack:
                                return attack(closest_enemy.id)
                            return random.choice(attack_actions)
                        else:
                            if distance > criterion:
                                path_action = find_path(obs_data, safe_x, safe_y)
                                if path_action:
                                    return path_action
                        
                # If there are only enemies
                else:
                    # Closest enemy as target
                    closest_enemy = min(
                        [enemy for enemy in obs_data.enemies],
                        key=lambda enemy: enemy.distance,
                        default=None,
                    )
                    # No allies, kitting melee enemies
                    if closest_enemy.unit_type.lower() in ['zealot', 'zergling', 'baneling']:
                        if closest_enemy.distance <= 4 / obs_data.own_sight_range and obs_data.last_action >= 6:
                            enemy_x = sum(e.position[0] for e in obs_data.enemies) / len(obs_data.enemies)
                            enemy_y = sum(e.position[1] for e in obs_data.enemies) / len(obs_data.enemies)
                            target_x = - enemy_x
                            target_y = - enemy_y

                            g_x = target_x * obs_data.own_sight_range + (obs_data.own_position[0] * 32)
                            g_y = target_y * obs_data.own_sight_range + (obs_data.own_position[1] * 32)
                            if not (0 <= g_x <= 32 and 0 <= g_y <= 32):
                                target_x = (0.5 - obs_data.own_position[0]) * 32 / obs_data.own_sight_range
                                target_y = (0.5 - obs_data.own_position[1]) * 32 / obs_data.own_sight_range

                            path_action = find_path(obs_data, target_x, target_y)
                            if path_action:
                                return path_action
                        if closest_enemy.can_attack:
                            return attack(closest_enemy.id)
                    else:
                        # No melee enemies, highest priority enemy as target
                        target_scores = {enemy.id: score_target(enemy) for enemy in obs_data.enemies}
                        # Check if there are same max score targets
                        max_score = max(target_scores.values())
                        max_score_target_ids = [target_id for target_id, score in target_scores.items() if score == max_score]
                        max_score_targets = [enemy for enemy in obs_data.enemies if enemy.id in max_score_target_ids]
                        if len(max_score_targets) > 1:
                            # Chose the closest target
                            best_target = min(max_score_targets, key=lambda x: x.distance)
                        else:
                            best_target = max_score_targets[0]
                        if best_target.can_attack:
                            return attack(best_target.id)
                        else:
                            # Best target is not in shoot range, move to target
                            dx = best_target.position[0]
                            dy = best_target.position[1]
                            if best_target.distance > obs_data.own_shoot_range / obs_data.own_sight_range:
                                path_action = find_path(obs_data, dx, dy, target_type=best_target.unit_type.lower())
                                if path_action:
                                    return path_action
                                elif attack_actions:
                                    if obs_data.last_action in attack_actions:
                                        return obs_data.last_action
                                    attackable_enemies = [enemy for enemy in obs_data.enemies if enemy.can_attack]
                                    closest_enemy = min(
                                        [enemy for enemy in attackable_enemies],
                                        key=lambda enemy: enemy.distance,
                                        default=None,
                                    )
                                    if closest_enemy and closest_enemy.can_attack:
                                        return attack(closest_enemy.id)
                                    return random.choice(attack_actions)
            # If there are no enemies
            else:
                # If there are only allies
                if obs_data.allies:
                    # Check if any melee ally
                    melee_ally = [ally for ally in obs_data.allies if ally.unit_type.lower() in ['zealot', 'zergling', 'baneling']]
                    if melee_ally:
                        # Move to center of melee allies
                        melee_ally_x = sum(ally.position[0] for ally in melee_ally) / len(melee_ally)
                        melee_ally_y = sum(ally.position[1] for ally in melee_ally) / len(melee_ally)
                        dx = melee_ally_x
                        dy = melee_ally_y
                        distance = (dx ** 2 + dy ** 2) ** 0.5
                        if distance > 0.05:
                            path_action = find_path(obs_data, dx, dy)
                            if path_action:
                                return path_action
                    else:
                        # No melee allies, move to target ally
                        ally_x = sum(ally.position[0] for ally in obs_data.allies) / len(obs_data.allies)
                        ally_y = sum(ally.position[1] for ally in obs_data.allies) / len(obs_data.allies)
                        distance = (ally_x ** 2 + ally_y ** 2) ** 0.5
                        if distance > 0.05:
                            dx = ally_x
                            dy = ally_y
                            path_action = find_path(obs_data, dx, dy)
                            if path_action:
                                return path_action
        return default_action()

    return control_logic()
    
# Names exported by `from <this module> import *`: the tactic entry point
# and the A* pathfinding helper `find_path`.
__all__ = ["default_tactic", "find_path"]