import math
from ..utils.distance_api import *
from ..utils.actions_api import *
from ..utils.units_api import *

from ..unit_typeid import UnitTypeId
from scipy.spatial.distance import pdist, squareform
import numpy as np

class DecisionTreeScript():
    """Hand-written decision rules for a mixed baneling/zergling squad.

    Each call to :meth:`script` turns the current unit observations into a
    list of actions built with the project's ``attack`` and ``move`` helpers.
    """

    # Number of initial iterations during which zerglings fan out instead of
    # attacking.
    _SPREAD_ITERATIONS = 15
    # Distance passed to ``closer_than`` when estimating how clustered an
    # enemy is.
    _DENSITY_RADIUS = 2

    def __init__(self, map_name, *, radius=10, center=(16, 21)):
        """Store the map name and the geometry of the opening spread.

        Args:
            map_name: Identifier of the map; stored but not otherwise read
                by this class.
            radius: Distance from ``center`` at which zerglings are placed
                while spreading out.
            center: ``(x, y)`` point the spread arc is centred on.
        """
        self.map_name = map_name
        self.radius = radius
        self.center = center

    def script(self, agents, enemies, agent_ability, visible_matrix, iteration):
        """Build the list of actions for the current step.

        Banelings all attack the enemy with the most neighbours (zerglings
        preferred, any enemy otherwise). For the first
        ``_SPREAD_ITERATIONS`` iterations zerglings are moved onto a
        half-circle arc around ``self.center``; afterwards each zergling
        attacks its closest enemy.

        Args:
            agents: Mapping of id -> friendly unit. Units need ``health``
                and ``unit_type`` attributes.
            enemies: Mapping of id -> enemy unit, same unit attributes.
            agent_ability: Unused; kept for interface compatibility.
            visible_matrix: Passed through unchanged to ``attack``.
            iteration: Current step counter, compared against
                ``_SPREAD_ITERATIONS``.

        Returns:
            A list of actions, or an empty list when there is no living
            friendly unit or no living enemy baneling/zergling.
        """
        # Change from dict to list, dropping units with zero health.
        agents = [agent for agent in agents.values() if agent.health != 0]
        enemies = [enemy for enemy in enemies.values() if enemy.health != 0]

        if not agents or not enemies:
            return []

        baneling_id = UnitTypeId.BANELING.value
        zergling_id = UnitTypeId.ZERGLING.value

        self.banelings = [u for u in agents if u.unit_type == baneling_id]
        self.zerglings = [u for u in agents if u.unit_type == zergling_id]
        self.enemy_banelings = [u for u in enemies if u.unit_type == baneling_id]
        self.enemy_zerglings = [u for u in enemies if u.unit_type == zergling_id]

        # Order matters: ties in the selections below resolve to the first
        # candidate, so enemy banelings win ties over enemy zerglings.
        enemy_units = self.enemy_banelings + self.enemy_zerglings
        if not enemy_units:
            return []

        actions_list = []

        # Baneling attack zerglings. The chosen target does not depend on
        # the individual baneling, so it is computed once for all of them.
        if self.banelings:
            target = self.find_best_attack_target(self.enemy_zerglings)
            if target is None:
                # No enemy zerglings: fall back to the densest enemy overall.
                target = self.find_best_attack_target(enemy_units)
            actions_list.extend(
                attack(baneling, target, visible_matrix)
                for baneling in self.banelings
            )

        # Zergling spread out.
        if iteration < self._SPREAD_ITERATIONS:
            zergling_count = len(self.zerglings)
            for index, zergling in enumerate(self.zerglings):
                # Angles are evenly spaced over [0, pi), i.e. a half circle.
                angle = math.pi / zergling_count * index
                destination = (
                    math.cos(angle) * self.radius + self.center[0],
                    math.sin(angle) * self.radius + self.center[1],
                )
                actions_list.append(move(zergling, destination))
            return actions_list

        for zergling in self.zerglings:
            target = min(enemy_units, key=lambda eu: distance_to(eu, zergling))
            actions_list.append(attack(zergling, target, visible_matrix))

        return actions_list

    def find_best_attack_target(self, enemies):
        """Return the enemy with the most other enemies close to it.

        Density is the length of ``closer_than(enemy, enemies, radius)``
        with ``radius = _DENSITY_RADIUS``. On ties the earliest enemy in
        ``enemies`` wins.

        Args:
            enemies: Sequence of candidate enemy units.

        Returns:
            The selected enemy, or ``None`` when ``enemies`` is empty.
        """
        return max(
            enemies,
            key=lambda enemy: len(
                closer_than(enemy, enemies, self._DENSITY_RADIUS)
            ),
            default=None,
        )