import numpy as np
from scipy.cluster.vq import ClusterError, kmeans2
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import DBSCAN

def distance_matrix_AB(A, B):
    '''Return pairwise Euclidean distances between two sets of 2-D points.

    Args:
        A: array-like of shape (..., Na, 2). A single point of shape (2,)
            is treated as (1, 2).
        B: array-like of shape (..., Nb, 2), likewise. Leading (batch)
            dimensions of A and B must broadcast against each other.

    Returns:
        ndarray of shape (..., Na, Nb) whose [..., i, j] entry is
        ``||B[..., j, :] - A[..., i, :]||``.

    Raises:
        ValueError: if the last axis of A or B does not have length 2.
    '''
    A = np.atleast_2d(np.asarray(A))
    B = np.atleast_2d(np.asarray(B))
    if A.shape[-1] != 2 or B.shape[-1] != 2:
        raise ValueError(
            f'expected 2-D points (last axis of length 2), got shapes {A.shape} and {B.shape}')
    # Broadcast (..., 1, Nb, 2) against (..., Na, 1, 2) instead of
    # materialising repeated copies of both inputs.
    diff = B[..., np.newaxis, :, :] - A[..., :, np.newaxis, :]  # (..., Na, Nb, 2)
    return np.linalg.norm(diff, axis=-1)

class CheatScriptAI():
    '''Hand-written (scripted) policy for the attacking side.

    NOTE(review): the methods read ``self.dim_p``, ``self.start_flag``,
    ``self.s_cfg`` and ``self.shared_resorce``, none of which this class
    sets, so it is presumably mixed into a host class that provides them
    -- confirm against the callers.
    '''

    # Tuning constants; class attributes so a subclass can override them.
    GUARD_CLUSTER_EPS = 0.5       # DBSCAN ``eps`` used to group alive guards
    GUARD_CLUSTER_MIN_POINTS = 2  # DBSCAN ``min_samples``
    TEAM_SIZE = 5                 # desired attackers per team (sets k for k-means)
    MIN_TEAM_SIZE = 3             # a team this small or smaller triggers regrouping
    PATROL_RADIUS = 0.8           # radius of the ring of slots around a target

    def attackers_policy_1(self, attackers_agent, guard_agents):
        '''Set this step's movement and firing commands for every attacker.

        1. Alive guards are grouped with DBSCAN; each cluster becomes one
           target centre and each DBSCAN noise guard is a target of its own.
        2. Attackers are split into teams with k-means (see ``get_groups``).
           The split is made on the first call (``self.start_flag``) and is
           redone for this step if any team has ``MIN_TEAM_SIZE`` members or
           fewer.
        3. Teams are matched one-to-one to targets with the Hungarian
           algorithm (repeating the last target, or dropping extra ones, so
           the counts agree), and each team is spread over slots on a ring
           around its target.
        4. Every alive attacker aims at its nearest guard according to
           ``self.shared_resorce`` and is allowed to fire.

        Args:
            attackers_agent: attacker objects (``pos``, ``alive`` and
                ``iden`` are read; ``act``, ``can_fire``, ``local_id`` and
                ``atk_rad`` are written).
            guard_agents: guard objects (``pos`` and ``alive`` are read).

        Returns None; results are written onto the agents. When no guard is
        alive the attackers are only reset (zero ``act``, ``can_fire`` False).
        '''
        for i, agent in enumerate(attackers_agent):
            agent.act = np.zeros(self.dim_p)  ## We'll use this now for Graph NN
            agent.can_fire = False
            agent.local_id = i

        alive_guards = [guard for guard in guard_agents if guard.alive]
        if not alive_guards:
            # DBSCAN rejects an empty sample set, and there is nothing to
            # approach or shoot at, so leave every attacker idle.
            return
        guard_positions = np.array([guard.pos for guard in alive_guards], dtype=float)
        target_centres = self._guard_cluster_centres(guard_positions)

        if self.start_flag:
            self.teams_result_step1, self.team_centroid_step1 = self.get_groups(
                attackers_agent, self.TEAM_SIZE)
            self.start_flag = False
        teams = self.teams_result_step1
        team_centroids = self.team_centroid_step1
        if any(len(members) <= self.MIN_TEAM_SIZE for members in teams.values()):
            # NOTE(review): this regrouping is used for the current step only
            # (it is not stored back into self.teams_result_step1), so it is
            # redone, with a new random seed, on every later step as well --
            # confirm that is intended.
            teams, team_centroids = self.get_groups(attackers_agent, self.TEAM_SIZE)

        # Exactly one target per team: repeat the last centre when there are
        # fewer targets than teams, drop the surplus when there are more.
        n_teams = len(teams)
        targets = target_centres[:n_teams]
        targets += [targets[-1]] * (n_teams - len(targets))

        # Match teams to targets by distance from each team's k-means
        # centroid. Only the rows for labels that actually own agents are
        # used, so the cost matrix is square and every team gets a target
        # even when k-means left a label empty.
        # NOTE(review): these centroids come from the last get_groups call,
        # not from the agents' current positions -- confirm that is intended.
        team_keys = list(teams.keys())
        team_xy = np.asarray(team_centroids, dtype=float)[team_keys]
        # reshape (not squeeze) so a single target stays (1, 2).
        target_xy = np.array(targets, dtype=float).reshape(-1, 2)
        _, target_of_row = linear_sum_assignment(distance_matrix_AB(team_xy, target_xy))
        target_of_team = dict(zip(team_keys, target_of_row))

        # -10 is used as "no leader": presumably no agent has iden == -10, so
        # every team member is treated as a follower. TODO confirm.
        self.leader_id = -10
        for key, members in teams.items():
            self._move_team_to_ring(members, targets[target_of_team[key]])

        self._aim_at_nearest_guard(attackers_agent, guard_agents)

    def _guard_cluster_centres(self, guard_positions):
        '''Cluster guard_positions with DBSCAN and return one centre per target.

        Each DBSCAN cluster contributes its k-means (k=1) centroid; every
        noise point (label -1) is a target by itself. Targets follow the
        order in which their label first appears in DBSCAN's ``labels_``.
        '''
        labels = DBSCAN(self.GUARD_CLUSTER_EPS,
                        min_samples=self.GUARD_CLUSTER_MIN_POINTS).fit(guard_positions).labels_
        members = {}
        for label, position in zip(labels, guard_positions):
            members.setdefault(label, []).append(position)

        centres = []
        for label, points in members.items():
            points = np.array(points)
            if label == -1:
                centres.extend(points)
            else:
                centroid, _ = self._kmeans2_with_retry(points, 1)
                centres.append(centroid[0])
        return centres

    def _move_team_to_ring(self, team_agents, centre):
        '''Send the followers in team_agents to slots on a ring around centre.

        One slot is created per follower (agents whose ``iden`` differs from
        ``self.leader_id``), at angle ``delta_angle * i`` and radius
        ``PATROL_RADIUS``, where i is the follower's index in team_agents.
        Followers (dead ones included) are matched to slots by the Hungarian
        algorithm on Euclidean distance. Each alive follower's ``act[0:2]``
        becomes the displacement to its slot, negated when
        ``s_cfg.DISALBE_RED_FUNCTION`` is set.
        '''
        # NOTE(review): operator precedence makes this pi/n - 1, which is
        # negative for n >= 4, so the slots bunch on a short arc instead of
        # spreading round the ring; pi/(n - 1) or 2*pi/n was probably meant.
        # Kept unchanged pending confirmation.
        delta_angle = float(np.pi / len(team_agents) - 1)
        followers, slots = [], []
        for i, agent in enumerate(team_agents):
            if agent.iden != self.leader_id:
                angle = delta_angle * i
                slots.append(centre + self.PATROL_RADIUS * np.array([np.cos(angle), np.sin(angle)]))
                followers.append(agent)
        if not followers:
            return

        slots = np.array(slots)
        positions = np.array([[agent.pos[0], agent.pos[1]] for agent in followers], dtype=float)
        cost = np.linalg.norm(positions[:, np.newaxis, :] - slots[np.newaxis, :, :], axis=-1)
        # Square cost matrix, so row j of the assignment is follower j.
        _, slot_of = linear_sum_assignment(cost)

        for agent, slot_index in zip(followers, slot_of):
            if not agent.alive:
                continue
            agent.act[0] = slots[slot_index][0] - agent.pos[0]
            agent.act[1] = slots[slot_index][1] - agent.pos[1]
            if self.s_cfg.DISALBE_RED_FUNCTION:
                agent.act[0] = -agent.act[0]
                agent.act[1] = -agent.act[1]

    def _aim_at_nearest_guard(self, attackers_agent, guard_agents):
        '''Aim every alive attacker at its nearest guard and let it fire.

        Distances are read from ``self.shared_resorce``. The rows of
        'distance_matrix' are selected through 'attackers_uid', indexed by
        each attacker's ``local_id``. The columns are selected through
        'guards_uid'. The bearing to the chosen guard (from ``np.arctan2``,
        shifted by 2*pi when negative) is stored in ``atk_rad``. ``can_fire``
        is set to True unless ``s_cfg.DISALBE_RED_FUNCTION`` is set.

        Presumably the host fills the shared resource as:
            self.shared_resorce['distance_matrix'] = self.dis
            self.shared_resorce['guards_uid'] = guards_uid
            self.shared_resorce['attackers_uid'] = attackers_uid
        '''
        alive_agents = [agent for agent in attackers_agent if agent.alive]
        alive_ids = [agent.local_id for agent in alive_agents]
        dis = self.shared_resorce['distance_matrix']
        attackers_uids = self.shared_resorce['attackers_uid']
        rows = dis[attackers_uids[alive_ids], ...]
        dis_to_guards = rows[:, self.shared_resorce['guards_uid']]
        nearest = np.argmin(dis_to_guards, axis=-1)
        # NOTE(review): ``nearest`` is a column position within guards_uid but
        # indexes guard_agents directly; this assumes guards_uid lists guards
        # in guard_agents order -- TODO confirm.
        for agent, index in zip(alive_agents, nearest):
            delta = guard_agents[index].pos - agent.pos
            theta = np.arctan2(delta[1], delta[0])
            if theta < 0:
                theta += 2 * np.pi
            agent.atk_rad = theta
            agent.can_fire = True
            if self.s_cfg.DISALBE_RED_FUNCTION:
                agent.can_fire = False

    @staticmethod
    def _kmeans2_with_retry(data, k, attempts=5):
        '''Run k-means++ (20 iterations) on data, retrying on empty clusters.

        Up to ``attempts`` runs use ``missing='raise'``, each with a seed
        drawn from the global NumPy RNG. If every run ends with an empty
        cluster, one final run with ``missing='warn'`` is returned as-is
        (scipy emits a warning).

        Returns:
            (centroid, label) exactly as returned by
            ``scipy.cluster.vq.kmeans2``.
        '''
        for _ in range(attempts):
            try:
                return kmeans2(data, k, iter=20, minit='++',
                               seed=np.random.randint(100), missing='raise')
            except ClusterError:
                continue
        return kmeans2(data, k, iter=20, minit='++',
                       seed=np.random.randint(100), missing='warn')

    def get_groups(self, policy_agents, num_members):
        '''Split policy_agents into teams by k-means on their positions.

        Args:
            policy_agents: agents with a ``pos`` attribute.
            num_members: desired team size. The number of teams is
                ``len(policy_agents) // num_members``, at least 1. A
                non-positive num_members is treated as 1.

        Returns:
            (team_results, team_centroid). team_results maps each k-means
            label that received at least one agent to the list of those
            agents, in input order; keys appear in order of first appearance.
            team_centroid is the k-means centroid array, whose row ``label``
            is the centroid of team ``label``. Some labels may be absent from
            team_results if a cluster stayed empty.
        '''
        num_members = max(num_members, 1)
        num_team = max(len(policy_agents) // num_members, 1)
        positions = np.asarray([agent.pos for agent in policy_agents], dtype=float)
        team_centroid, team_labels = self._kmeans2_with_retry(positions, num_team)

        team_results = {}
        for agent, label in zip(policy_agents, team_labels):
            team_results.setdefault(label, []).append(agent)
        return team_results, team_centroid

