import math
from ..utils.distance_api import *
from ..utils.actions_api import *
from ..utils.units_api import *

from ..unit_typeid import UnitTypeId
from scipy.spatial.distance import pdist, squareform
import numpy as np


class DecisionTreeScript:
    """Hand-written controller that produces one batch of baneling orders per call.

    The instance keeps the map name it was built with and, after each
    successful ``script`` call, the tag-sorted lists ``banelings`` and
    ``enemy_zealots`` computed for that call.
    """

    def __init__(self, map_name):
        """Store ``map_name`` on the instance; nothing else is initialised."""
        self.map_name = map_name

    def script(self, agents, enemies, agent_ability, visible_matrix, iteration):
        """Build the list of actions for the current step.

        Args:
            agents: Mapping whose values are friendly units; each unit is
                read for ``health``, ``unit_type`` and ``tag``.
            enemies: Mapping whose values are enemy units, read the same way.
            agent_ability: Accepted for interface compatibility; not used here.
            visible_matrix: Passed through unchanged to ``attack``.
            iteration: Step counter; values up to and including 10 select
                the opening behaviour.

        Returns:
            ``None`` when either side has no unit with non-zero health,
            otherwise a list holding the results of the ``attack`` / ``move``
            helper calls (possibly empty).
        """
        # Units whose health is exactly 0 are ignored on both sides.
        alive_agents = [unit for unit in agents.values() if unit.health != 0]
        alive_enemies = [unit for unit in enemies.values() if unit.health != 0]

        # Nothing to decide unless both sides still have units.
        if not (alive_agents and alive_enemies):
            return None

        def by_tag(unit):
            return unit.tag

        # Sorting by tag keeps the ordering stable from one call to the next.
        self.banelings = sorted(
            (unit for unit in alive_agents
             if unit.unit_type == UnitTypeId.BANELING.value),
            key=by_tag,
        )
        self.enemy_zealots = sorted(
            (unit for unit in alive_enemies
             if unit.unit_type == UnitTypeId.ZEALOT.value),
            key=by_tag,
        )

        # Opening phase: every baneling gets an attack order on a fixed point.
        if iteration <= 10:
            return [
                attack(baneling, (4.5, 4.5), visible_matrix)
                for baneling in self.banelings
            ]

        orders = []
        for baneling in self.banelings:
            # The search runs over all living enemies, not only zealots.
            closest_enemy = nearest_n_units(baneling, alive_enemies, 1)[0]
            if closest_enemy:
                # NOTE(review): the negative distance presumably yields a point
                # on the far side of the baneling from the enemy - confirm
                # against the `toward` helper.
                destination = toward(baneling, closest_enemy, -2)
                orders.append(move(baneling, destination))

        return orders