import math
from ..utils.distance_api import *
from ..utils.actions_api import *
from ..utils.units_api import *

from ..unit_typeid import UnitTypeId
from scipy.spatial.distance import pdist, squareform
import numpy as np

class DecisionTreeScript:
    """Scripted baneling/zergling micro policy.

    Living friendly units are dealt alternately into two groups:

    * the strike group attacks the centre of the enemy zerglings (or of the
      enemy banelings when no enemy zerglings remain);
    * the flank group first moves to ``rally_point`` and, once its centre is
      closer than ``rally_radius`` to that point, each unit attacks its
      nearest enemy.

    Only enemy banelings and zerglings are treated as targets.
    """

    # Default rally point for the flank group, and how close (strictly less
    # than) the group's centre must be to it before the group engages.
    RALLY_POINT = (16, 8)
    RALLY_RADIUS = 2

    def __init__(self, map_name, rally_point=RALLY_POINT, rally_radius=RALLY_RADIUS):
        """Create the script.

        Args:
            map_name: Name of the map this script is used on (stored only).
            rally_point: Position the flank group gathers at before engaging.
            rally_radius: Distance from ``rally_point`` below which the
                flank group's centre counts as "arrived".
        """
        self.map_name = map_name
        # NOTE(review): `radius` and `center` are not read anywhere in this
        # class; kept in case external code relies on them.
        self.radius = 10
        self.center = (16, 21)
        self.rally_point = rally_point
        self.rally_radius = rally_radius

    def script(self, agents, enemies, agent_ability, visible_matrix, iteration):
        """Produce this step's actions for every living friendly unit.

        Args:
            agents: Mapping of friendly units; only the values are used, and
                each must expose ``health`` and ``unit_type``.
            enemies: Mapping of enemy units with the same attributes.
            agent_ability: Unused by this script.
            visible_matrix: Forwarded unchanged to ``attack``.
            iteration: Unused by this script.

        Returns:
            A list of actions built by the ``attack``/``move`` helpers. The
            list is empty when there is nothing to act on.

        Side effects:
            Stores the current ``banelings``, ``zerglings``,
            ``enemy_banelings`` and ``enemy_zerglings`` lists on ``self``.
        """
        # Keep only living units; the mapping keys are not needed.
        agents = [agent for agent in agents.values() if agent.health != 0]
        enemies = [enemy for enemy in enemies.values() if enemy.health != 0]

        if not agents or not enemies:
            return []

        baneling_id = UnitTypeId.BANELING.value
        zergling_id = UnitTypeId.ZERGLING.value
        self.banelings = [unit for unit in agents if unit.unit_type == baneling_id]
        self.zerglings = [unit for unit in agents if unit.unit_type == zergling_id]
        self.enemy_banelings = [unit for unit in enemies if unit.unit_type == baneling_id]
        self.enemy_zerglings = [unit for unit in enemies if unit.unit_type == zergling_id]

        # Only enemy banelings and zerglings are considered as targets.
        enemies = self.enemy_banelings + self.enemy_zerglings
        if not enemies:
            # Return an empty list (not None) so every exit yields a list.
            return []

        # Deal units alternately into the two groups: banelings first, then
        # zerglings, each type restarting at the strike group.
        strike_group = self.banelings[0::2] + self.zerglings[0::2]
        flank_group = self.banelings[1::2] + self.zerglings[1::2]

        actions_list = []

        if strike_group:
            # Prefer enemy zerglings; fall back to banelings when none are left.
            strike_target = center(self.enemy_zerglings or self.enemy_banelings)
            for unit in strike_group:
                actions_list.append(attack(unit, strike_target, visible_matrix))

        if flank_group:
            if distance_to(center(flank_group), self.rally_point) < self.rally_radius:
                # Arrived at the rally point: every unit engages its nearest enemy.
                for unit in flank_group:
                    target = min(enemies, key=lambda e: distance_to(e, unit))
                    actions_list.append(attack(unit, target, visible_matrix))
            else:
                for unit in flank_group:
                    actions_list.append(move(unit, self.rally_point))

        return actions_list