import numpy as np
import matplotlib.pyplot as plt
import torch
import math

def data_mask(envs, state, obs, av_action):
    """Blank out opponent information and restrict the available actions.

    Args:
        envs: Environment object providing ``n_agents``, ``n_enemies`` and
            the ``get_*`` size accessors used below.
        state: Flat global state vector. The slice that follows the ally
            block is zeroed IN PLACE, so the caller's object is modified.
        obs: Sequence of per-agent observation vectors. They are stacked
            into a new array, so the inputs themselves are left untouched.
        av_action: Per-agent rows of action-availability flags.

    Returns:
        Tuple ``(state, obs, avail_actions)``: the (mutated) ``state``, a
        list of masked observation rows and a nested list of masked
        availability flags.
    """
    # Every action column from index 6 onwards is marked unavailable.
    # NOTE(review): presumably these are the attack actions -- confirm.
    action_mask = np.array(av_action)
    action_mask[:, 6:] = 0

    # Zero the enemy block of the state.
    # Assumes the state lists all ally attributes first, immediately
    # followed by the enemy attributes -- TODO confirm against the env.
    enemy_start = envs.get_ally_num_attributes() * envs.n_agents
    enemy_stop = enemy_start + envs.get_enemy_num_attributes() * envs.n_enemies
    state[enemy_start:enemy_stop] = 0

    # Zero the enemy block of every observation; it starts right after the
    # movement features.
    move_size = envs.get_obs_move_feats_size()
    enemy_size = np.prod(envs.get_obs_enemy_feats_size())
    stacked = np.stack(obs)
    stacked[:, move_size:move_size + enemy_size] = 0

    return state, list(stacked), action_mask.tolist()

class RewardRecon:
    def __init__(self, env, batch, t, obs_t1, state_t1, args):
        self.envs = env
        self.env_info = self.envs.get_env_info()
        self.batch = batch
        self.t = t
        self.n_agents = self.env_info["n_agents"]
        self.n_enemies = self.envs.n_enemies
        self.reward_matrix = np.zeros((1,self.n_agents))
        self.obs_t0_matrix = self.batch['obs'][:,self.t]
        self.obs_t1_matrix = obs_t1
        self.state_t0_matrix = self.batch['state'][:,self.t]
        self.state_t0_matrix = self.state_t0_matrix.cpu().numpy().squeeze()
        self.state_t1_matrix = state_t1
        self.obs_ally_inl = self.envs.get_obs_move_feats_size() + np.prod(self.envs.get_obs_enemy_feats_size())
        self.n_allies = self.envs.get_obs_ally_feats_size()[0]
        self.ally_dim = self.envs.get_obs_ally_feats_size()[1]
        self.map_x, self.map_y = 32, 32
        self.param_alpha = args.param_alpha

    def tacit_reward(self,nt,nr,sr, tac):
        c_pattern = self.classify_pattern()

        for k in range(self.n_agents):
            obs_k_t0 = self.obs_t0_matrix[:,k,:]
            obs_k_t0 = obs_k_t0.cpu().numpy().squeeze()
            obs_k_t1 = self.obs_t1_matrix[k]
            if np.all(obs_k_t0 == 0):
                continue
            if np.all(obs_k_t1 == 0):
                continue

            if c_pattern[k] == 0:
                lamda_1 = self.pattern_1or2_membership(k, obs_k_t0)
                self.tacit_1_reward(k, lamda_1)
            elif c_pattern[k] == 1:
                lamda_2 = self.pattern_1or2_membership(k, obs_k_t0)
                self.tacit_2_reward(k, lamda_2, obs_k_t0, obs_k_t1)
            elif c_pattern[k] == 2:
                lamda_3 = self.pattern_3_membership(obs_k_t0)
                self.tacit_3_reward(lamda_3, k, obs_k_t0, obs_k_t1)
            elif c_pattern[k] == 3:
                lamda_4 = self.pattern_4_membership(k, obs_k_t0)
                self.tacit_4_reward(lamda_4, k, obs_k_t0, obs_k_t1)

            nt, nr, sr, tac = self.calculate_tacit(k, c_pattern[k], nt, nr, sr, tac)
        return self.reward_matrix, nt, nr, sr

 # -----------------------------------------------------------------------------------------------
    def pattern_1or2_membership(self, k, obs_k):
        """Return the membership degree used for patterns 1 and 2.

        The degree is a ramp over the distance ``d_min`` from agent ``k``
        to its nearest non-visible ally (first value returned by
        ``non_min_distance``), with ``d_sight = 1``:

        * ``d_min > 1.5 * d_sight``             -> 1
        * ``d_sight < d_min <= 1.5 * d_sight``  -> linear from 0 to 1
        * otherwise (including NaN)             -> 0

        When ``allies_alive()`` reports 0 the degree is 1 and no distance
        is computed.

        The original middle branch tested only ``d_min <= 1.5 * d_sight``,
        which made the final ``else: 0`` unreachable and produced negative
        degrees for ``d_min < d_sight``; the ramp is now bounded below.

        Args:
            k: Index of the agent.
            obs_k: Observation vector of agent ``k`` at the "before" step.

        Returns:
            Membership degree in ``[0, 1]``.
        """
        d_sight = 1
        if self.allies_alive() == 0:
            return 1

        d_min = self.non_min_distance(k, obs_k)[0]
        d_full = 1.5 * d_sight
        if d_min > d_full:
            return 1
        if d_min > d_sight:
            return (d_min - d_sight) / (d_full - d_sight)
        # d_min <= d_sight (or NaN): no membership instead of a negative one.
        return 0

    def pattern_3_membership(self,obs_k):
        d_min = self.min_distance(obs_k)[0]
        d_max = self.max_distance(obs_k)[0]
        c_alpha = self.param_alpha
        d0 = c_alpha * 1
        if d_min < d0:
            lamda_3 = 1 - (d_min - 0) / (d0 - 0)
        elif 1 > d_min >= d0:
            lamda_3 = 1 - (1 - d_max) / (1 - d0)
        else:
            print("d_min, d_max: ", d_min, d_max)
        return lamda_3

    def pattern_4_membership(self, k, obs_k):
        d_min, ally_agent_id = self.min_distance(obs_k)
        agent_type = self.distinguish_agents(obs_k, ally_agent_id)
        count_allies_alive = self.allies_alive()
        d_min_non = self.non_min_distance(k, obs_k)[0]
        if count_allies_alive != 1:
            c_beta = 1.5
            if d_min_non <= c_beta * 1:
                lamda_4_lead = (d_min_non - 1) / (c_beta * 1 - 1)
            elif d_min_non > c_beta * 1:
                lamda_4_lead = 1
        elif count_allies_alive == 1:
            agent_type = 0
        c_alpha = self.param_alpha
        if d_min < c_alpha * 1:
            lamda_4_follow = (d_min - 0) / (c_alpha * 1 - 0)
            lamda_4_follow = 1 - lamda_4_follow
        elif 1 > d_min >= c_alpha * 1:
            lamda_4_follow = (1 - d_min) / (1 - c_alpha * 1)
            lamda_4_follow = 1 - lamda_4_follow
        else:
            print("d_min: ", d_min)

        if agent_type == 1:
            lamda_4 = lamda_4_lead
        elif agent_type == 0:
            lamda_4 = lamda_4_follow
        return lamda_4

    def tacit_1_reward(self, k, lamda_1):
        nf_al = self.envs.get_ally_num_attributes()
        sight_range = 9
        own_feats_x_t0 = self.state_t0_matrix[nf_al * k+2] * self.map_x / sight_range
        own_feats_y_t0 = self.state_t0_matrix[nf_al * k+3] * self.map_y / sight_range
        d0 = math.sqrt(own_feats_x_t0 ** 2 + own_feats_y_t0 ** 2)
        own_feats_x_t1 = self.state_t1_matrix[nf_al * k+2] * self.map_x / sight_range
        own_feats_y_t1 = self.state_t1_matrix[nf_al * k+2] * self.map_y / sight_range
        d1 = math.sqrt(own_feats_x_t1 ** 2 + own_feats_y_t1 ** 2)
        r1_tacit = (d0 - d1) * lamda_1
        self.reward_matrix[:, k] = np.array([r1_tacit], dtype=np.float32)
        return r1_tacit

    def tacit_2_reward(self, k, lamda_2, obs_k_t0, obs_k_t1):
        d0, d1 = self.non_min_distance(k, obs_k_t0)
        r2_tacit = (d0 - d1) * lamda_2
        self.reward_matrix[:, k] = np.array([r2_tacit], dtype=np.float32)
        return r2_tacit

    def tacit_3_reward(self, lamda_3, k, obs_k_t0, obs_k_t1):
        d_min_t0, min_ally_id = self.min_distance(obs_k_t0)
        d_min_t1 = obs_k_t1[self.obs_ally_inl + self.ally_dim * min_ally_id + 1]
        r_min = d_min_t1 - d_min_t0
        d_max_t0, max_ally_id = self.max_distance(obs_k_t0)
        d_max_t1 = obs_k_t1[self.obs_ally_inl + self.ally_dim * max_ally_id + 1]
        r_max = d_max_t0 - d_max_t1
        c_alpha = self.param_alpha
        d0 = c_alpha * 1
        if d_min_t0 < d0:
            r3_tacit = lamda_3 * r_min
        elif d_min_t0 >= d0 and d_min_t0 <= 1:
            r3_tacit = lamda_3 * r_max
        else:
            print("d_min_t0, d_max_t0: ", d_min_t0, d_max_t0)
        self.reward_matrix[:, k] = np.array([r3_tacit], dtype=np.float32)
        return r3_tacit

    def tacit_4_reward(self, lamda_4, k, obs_k_t0, obs_k_t1):
        df_t0, ally_agent_id = self.min_distance(obs_k_t0)
        agent_type = self.distinguish_agents(obs_k_t0, ally_agent_id)
        count_allies_alive = self.allies_alive()
        dl_t0, dl_t1 = self.non_min_distance(k, obs_k_t0)
        if count_allies_alive != 1:
            r4_tacit_lead = (dl_t0 - dl_t1) * lamda_4
        elif count_allies_alive == 1:
            agent_type = 0

        df_t1 = obs_k_t1[self.obs_ally_inl + self.ally_dim * ally_agent_id + 1]
        if df_t1 == 0:
            r4_tacit = 0
            return r4_tacit
        else:
            c_alpha = self.param_alpha
            if df_t0 < c_alpha * 1:
                r4_tacit_follow = (df_t1 - df_t0) * lamda_4
            elif df_t0 >= c_alpha * 1 and df_t0 <= 1:
                r4_tacit_follow = (df_t0 - df_t1) * lamda_4

            if agent_type == 1:
                r4_tacit = r4_tacit_lead
            elif agent_type == 0:
                r4_tacit = r4_tacit_follow
        self.reward_matrix[:, k] = np.array([r4_tacit], dtype=np.float32)
        return r4_tacit

    def calculate_tacit(self, k, lamda_id, nt, nr, sr, tac):
        lamda_id = int(lamda_id)
        # nt--num_all
        nt[lamda_id] += 1
        # sr--mean_r_tac
        sr[lamda_id] = (sr[lamda_id] * nt[lamda_id] + self.reward_matrix[:, k]) / (nt[lamda_id] + 1)
        if self.reward_matrix[:, k] >= 0:
            # nr--num_postive
            nr[lamda_id] += 1
        tac[lamda_id] = nr[lamda_id] / nt[lamda_id]
        return nt, nr, sr, tac
# ------------------------------------------------------------------------------------------------------------
    def classify_pattern(self):
        n_visible = np.zeros(self.n_agents)
        c_pattern = np.zeros(self.n_agents)
        for i in range(self.n_agents):
            obs_i = self.obs_t0_matrix[:,i,:]
            visible = np.zeros(self.n_allies)
            for j in range(self.n_allies):
                if obs_i[:,self.obs_ally_inl + self.ally_dim * j] == 0:
                    visible[j] = 0
                elif obs_i[:,self.obs_ally_inl + self.ally_dim * j] == 1:
                    visible[j] = 1
            if n_visible[i] == 0:
                if np.sum(visible) == 0:
                    n_visible[i] = 0
                elif np.sum(visible) == 1:
                    n_visible[i] = 1
                elif np.sum(visible) > 1:
                    n_visible[i] = 2
                    for j in range(self.n_agents - 1):
                        if visible[j] == 1:
                            if j < i:
                                n_visible[j] = 2
                            elif j >= i:
                                n_visible[j + 1] = 2

        if np.sum(n_visible) == 0:
            return c_pattern
        else:
            for i in range(self.n_agents):
                if n_visible[i] == 0:
                    c_pattern[i] = 1
                elif n_visible[i] == 1:
                    c_pattern[i] = 3
                elif n_visible[i] == 2:
                    c_pattern[i] = 2
            return c_pattern

    def allies_alive(self):
        nf_al = self.envs.get_ally_num_attributes()
        count_allies_alive = 0
        for i in range(self.n_allies):
            if self.state_t0_matrix[i * nf_al] != 0:
                count_allies_alive += 1
        return count_allies_alive

    def min_distance(self,obs_k):
        dist = np.inf
        for i in range(self.n_allies):
            if (obs_k[self.obs_ally_inl + i * self.ally_dim + 1] < dist
                    and obs_k[self.obs_ally_inl + i * self.ally_dim] == 1):
                dist = obs_k[self.obs_ally_inl + i * self.ally_dim + 1]
                ally_agent_id = i
        try:
            ally_agent_id
        except NameError:
            ally_agent_id = None
        return dist, ally_agent_id

    def max_distance(self, obs_k):
        dist = -np.inf
        for i in range(self.n_allies):
            if obs_k[self.obs_ally_inl + i * self.ally_dim] == 1:
                if obs_k[self.obs_ally_inl + i * self.ally_dim +1] > dist:
                    dist = obs_k[self.obs_ally_inl + i * self.ally_dim +1]
                    ally_agent_id = i
        try:
            ally_agent_id
        except NameError:
            ally_agent_id = None
        return dist, ally_agent_id

    def distinguish_agents(self, obs_k, ally_agent_id):
        rela_x = obs_k[self.obs_ally_inl + ally_agent_id * self.ally_dim +2]
        rela_y = obs_k[self.obs_ally_inl + ally_agent_id * self.ally_dim +3]
        if rela_x > 0:
            agent_type = 0  # follow_agent
        elif rela_x == 0 and rela_y > 0:
            agent_type = 0  # follow_agent
        else:
            agent_type = 1  # leader_agent
        return agent_type

    def non_min_distance(self, k, obs_k):
        dist = np.inf
        nf_al = self.envs.get_ally_num_attributes()
        sight_range = 9
        x_k = self.state_t0_matrix[k * nf_al + 2] * self.map_x / sight_range
        y_k = self.state_t0_matrix[k * nf_al + 3] * self.map_y / sight_range
        for i in range(self.n_allies):
            if obs_k[self.obs_ally_inl + i * self.ally_dim] == 0:
                if i < k:
                    x_i = self.state_t0_matrix[i * nf_al + 2] * self.map_x / sight_range
                    y_i = self.state_t0_matrix[i * nf_al + 3] * self.map_y / sight_range
                    agent_id_temp = i
                elif i >= k:
                    x_i = self.state_t0_matrix[(i+1) * nf_al + 2] * self.map_x / sight_range
                    y_i = self.state_t0_matrix[(i+1) * nf_al + 3] * self.map_y / sight_range
                    agent_id_temp = i + 1
                dist_temp = math.sqrt((x_i - x_k) ** 2 + (y_i - y_k) ** 2)
                if dist_temp < dist:
                    dist = dist_temp
                    agent_id = agent_id_temp

        x_k_t1 = self.state_t1_matrix[k * nf_al + 2] * self.map_x / sight_range
        y_k_t1 = self.state_t1_matrix[k * nf_al + 3] * self.map_y / sight_range
        x_i_t0 = self.state_t0_matrix[agent_id  * nf_al + 2] * self.map_x / sight_range
        y_i_t0 = self.state_t0_matrix[agent_id  * nf_al + 3] * self.map_y / sight_range
        dist_t1 = math.sqrt((x_i_t0 - x_k_t1) ** 2 + (y_i_t0 - y_k_t1) ** 2)
        return dist, dist_t1