import torch
import numpy as np
import math
from code_ptmc_mappo.envs.starcraft2.smac_maps import get_map_params
from copy import deepcopy

def data_mask(envs,obs,av_action):
    """Zero out part of the observations and of the action-availability mask.

    Both arrays are modified in place and the same objects are returned.

    Args:
        envs: Object exposing ``observation_space``; entries ``[0][1]`` and
            ``[0][2]`` are each read as a ``(count, width)`` pair.
        obs: Array indexed on three axes; a span of its last axis is zeroed.
        av_action: Array indexed on three axes; every entry from index 6
            onward on its last axis is zeroed.
            NOTE(review): presumably these are the attack actions in SMAC's
            action layout -- confirm against the environment.

    Returns:
        Tuple ``(obs, av_action)`` after masking.
    """
    av_action[:, :, 6:] = 0

    layout = envs.observation_space[0]
    # The masked span begins right after the block described by layout[1]
    # and is as long as the block described by layout[2].
    span_start = layout[1][0] * layout[1][1]
    span_stop = span_start + layout[2][0] * layout[2][1]
    obs[:, :, span_start:span_stop] = 0
    return obs, av_action

def reward_reconfiguration(n_rollout_threads, num_agents, envs, st_t0, st_t1, obs_t1, map_name, tacit_indicator, param_alpha):
    """Compute the per-agent "tacit" shaping reward for one transition.

    For every rollout thread the agents are classified into one of four
    patterns (codes 0-3) from the state at t0, and the matching
    membership/reward pair is evaluated on the agent's state at t0 and t1.

    Args:
        n_rollout_threads: Number of entries on the first axis of the states.
        num_agents: Number of agents (second axis of the states).
        envs: Environment handle forwarded to the helper functions.
        st_t0: States before the step, indexed ``[thread, agent, feature]``.
        st_t1: States after the step, same indexing.
        obs_t1: Forwarded to ``map_info`` only.
        map_name: Map identifier forwarded to ``map_info``.
        tacit_indicator: Statistics table updated through ``calculate_tacit``.
        param_alpha: Threshold parameter forwarded to patterns 2 and 3.

    Returns:
        Tuple ``(tacit_reward, tacit_indicator)`` where ``tacit_reward`` has
        shape ``(n_rollout_threads, num_agents, 1)`` and dtype float64.
    """
    def only_zeros_and_ones(vec):
        # True when every entry of the vector is exactly 0 or 1.
        return np.all((vec == 0) | (vec == 1))

    map_x, map_y, shield_bits_ally = map_info(map_name, st_t1, obs_t1, envs)
    tacit_reward = np.zeros((n_rollout_threads, num_agents, 1), dtype=np.float64)

    for thread in range(n_rollout_threads):
        visible_codes = agents_obs_visible(num_agents, envs, st_t0, thread)
        c_pattern = classify_pattern(num_agents, visible_codes)

        for agent in range(num_agents):
            before = st_t0[thread, agent, :]
            # Agents whose state is purely 0/1 get no shaping reward.
            # NOTE(review): presumably these are dead/padded agents -- confirm.
            if only_zeros_and_ones(before):
                continue
            after = st_t1[thread, agent, :]
            if only_zeros_and_ones(after):
                continue

            pattern = c_pattern[agent]
            if pattern == 0:
                weight = pattern_1or2_membership(before, envs)
                reward = tacit_1_reward(weight, before, after, envs, shield_bits_ally, map_x, map_y)
            elif pattern == 1:
                weight = pattern_1or2_membership(before, envs)
                reward = tacit_2_reward(weight, before, after, envs)
            elif pattern == 2:
                weight = pattern_3_membership(before, envs, param_alpha)
                reward = tacit_3_reward(weight, before, after, envs, param_alpha)
            elif pattern == 3:
                weight = pattern_4_membership(before, envs, param_alpha)
                reward = tacit_4_reward(weight, before, after, envs, param_alpha)
            else:
                # Unknown code: report it and abandon the remaining agents of
                # this thread (only the inner loop is left).
                print("c_pattern error: ", c_pattern, c_pattern[agent])
                break

            tacit_reward[thread, agent, :] = reward
            tacit_indicator = calculate_tacit(pattern, tacit_reward[thread, agent, :], tacit_indicator)

    return tacit_reward, tacit_indicator
# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------
def pattern_1or2_membership(agenti_st_t0, envs):
    """Return the membership weight used for patterns 0 and 1.

    The weight is a piecewise-linear function of the distance ``d_min`` to
    the nearest ally reported by ``min_distance``:

    * no ally counted by ``allies_alive``  -> 1
    * ``d_min > 1.5 * d_sight``            -> 1
    * ``d_sight < d_min <= 1.5 * d_sight`` -> linear ramp from 0 to 1
    * otherwise (including NaN)            -> 0

    The previous implementation tested ``d_min <= 1.5 * d_sight`` in its
    second branch, which made the final ``else: 0`` branch unreachable for
    ordinary numbers and produced a negative weight whenever ``d_min`` was
    below ``d_sight``. The weight is now clamped to 0 there.

    Args:
        agenti_st_t0: State vector of one agent at t0.
        envs: Environment handle forwarded to the distance helpers.

    Returns:
        Membership weight in ``[0, 1]``.
    """
    d_sight = 1
    d_far = 1.5 * d_sight

    if allies_alive(agenti_st_t0, envs) == 0:
        return 1

    d_min = min_distance(agenti_st_t0, envs)[0]
    if d_min > d_far:
        return 1
    if d_min > d_sight:
        return (d_min - d_sight) / (d_far - d_sight)
    # At or inside d_sight (or NaN): no membership.
    return 0

def pattern_3_membership(agenti_st_t0, envs, param_alpha):
    """Return the membership weight used for pattern 2.

    With ``d0 = param_alpha``:

    * ``d_min < d0``       -> ``1 - d_min / d0``
    * ``d0 <= d_min < 1``  -> ``1 - (1 - d_max) / (1 - d0)``
    * otherwise            -> 0 (a diagnostic line is printed)

    The last case previously fell through to ``return lamda_3`` with the
    variable never assigned, raising ``UnboundLocalError``. A weight of 0 is
    returned instead, so the corresponding reward term vanishes.

    Args:
        agenti_st_t0: State vector of one agent at t0.
        envs: Environment handle forwarded to the distance helpers.
        param_alpha: Threshold ``d0`` separating the two regimes.

    Returns:
        Membership weight.
    """
    d_min = min_distance(agenti_st_t0, envs)[0]
    d_max = max_distance(agenti_st_t0, envs)[0]
    d0 = param_alpha * 1

    if d_min < d0:
        return 1 - (d_min - 0) / (d0 - 0)
    if 1 > d_min >= d0:
        return 1 - (1 - d_max) / (1 - d0)

    # d_min is outside [0, 1) (or NaN): report it and contribute nothing.
    print("d_min, d_max: ", d_min, d_max)
    return 0

def pattern_4_membership(agenti_st_t0, envs,param_alpha):
    """Return the membership weight used for pattern 3.

    The agent is labelled leader (1) or follower (0) by
    ``distinguish_agents`` relative to its nearest ally; when
    ``allies_alive`` reports exactly one ally the agent is always treated as
    a follower.

    * Leader weight, from ``d_min_non`` (``non_min_distance``) with
      ``c_beta = 1.5``: ``(d_min_non - 1) / (c_beta - 1)`` when
      ``d_min_non <= c_beta``, else 1.
    * Follower weight, from ``d_min`` and ``param_alpha``:
      ``1 - d_min / alpha`` when ``d_min < alpha``;
      ``1 - (1 - d_min) / (1 - alpha)`` when ``alpha <= d_min < 1``.

    Previously the follower weight was left unassigned when ``d_min >= 1``
    (and the leader weight when ``d_min_non`` was NaN), so the function
    raised ``UnboundLocalError``. Both now default to 0.

    Args:
        agenti_st_t0: State vector of one agent at t0.
        envs: Environment handle forwarded to the helpers.
        param_alpha: Threshold used by the follower weight.

    Returns:
        Membership weight for the agent's role.
    """
    d_min, ally_agent_id = min_distance(agenti_st_t0, envs)
    agent_type = distinguish_agents(agenti_st_t0, envs, ally_agent_id)
    count_allies_alive = allies_alive(agenti_st_t0, envs)
    d_min_non = non_min_distance(agenti_st_t0, envs)[0]

    lamda_4_lead = 0
    if count_allies_alive != 1:
        c_beta = 1.5
        if d_min_non <= c_beta * 1:
            lamda_4_lead = (d_min_non - 1) / (c_beta * 1 - 1)
        elif d_min_non > c_beta * 1:
            lamda_4_lead = 1
    else:
        # With a single ally there is nobody to lead towards.
        agent_type = 0

    c_alpha = param_alpha
    if d_min < c_alpha * 1:
        lamda_4_follow = 1 - (d_min - 0) / (c_alpha * 1 - 0)
    elif 1 > d_min >= c_alpha * 1:
        lamda_4_follow = 1 - (1 - d_min) / (1 - c_alpha * 1)
    else:
        # d_min is outside [0, 1) (or NaN): report it and contribute nothing.
        print("d_min: ", d_min)
        lamda_4_follow = 0

    if agent_type == 1:
        return lamda_4_lead
    return lamda_4_follow

def tacit_1_reward(lamda_1, agenti_st_t0, agenti_st_t1, envs, shield_bits_ally, map_x, map_y):
    """Reward the reduction of a radial distance between t0 and t1.

    Two entries of the state vector are read, scaled by ``map_x / 9`` and
    ``map_y / 9`` respectively, and turned into a Euclidean norm. The reward
    is ``(norm_t0 - norm_t1) * lamda_1``.
    NOTE(review): the two entries are presumably the agent's own x/y
    position relative to the map centre -- confirm against the state layout.

    Args:
        lamda_1: Membership weight multiplying the distance change.
        agenti_st_t0: State vector of one agent before the step.
        agenti_st_t1: State vector of the same agent after the step.
        envs: Object exposing ``share_observation_space``; entries ``[0][0]``
            and ``[0][4]`` are used to locate the two state entries.
        shield_bits_ally: Extra offset (0 or 1 as produced by ``map_info``).
        map_x: Scale applied to the first entry.
        map_y: Scale applied to the second entry.

    Returns:
        The weighted change of the norm.
    """
    space = envs.share_observation_space[0]
    tail_start = space[0] - space[4][0] * space[4][1] + shield_bits_ally
    x_index = tail_start + 5
    y_index = tail_start + 6
    sight_range = 9

    def radius(state):
        # Norm of the two scaled entries of one state vector.
        x_val = state[x_index] * map_x / sight_range
        y_val = state[y_index] * map_y / sight_range
        return math.sqrt(x_val ** 2 + y_val ** 2)

    return (radius(agenti_st_t0) - radius(agenti_st_t1)) * lamda_1

def tacit_2_reward(lamda_2, agenti_st_t0, agenti_st_t1, envs):
    """Reward the agent for closing in on its nearest ally.

    The nearest ally is chosen from the state at t0 (``min_distance``); the
    distance to that same ally is then read from the state at t1.

    Args:
        lamda_2: Membership weight multiplying the distance change.
        agenti_st_t0: State vector of one agent before the step.
        agenti_st_t1: State vector of the same agent after the step.
        envs: Environment handle forwarded to the distance helpers.

    Returns:
        ``(distance_t0 - distance_t1) * lamda_2``.
    """
    nearest_before, nearest_id = min_distance(agenti_st_t0, envs)
    nearest_after = dist_between(nearest_id, agenti_st_t1, envs)
    closed_gap = nearest_before - nearest_after
    return closed_gap * lamda_2

def tacit_3_reward(lamda_3, agenti_st_t0, agenti_st_t1, envs, param_alpha):
    """Return the pattern-2 reward for one agent.

    With ``d0 = param_alpha`` and ``d_min_t0`` the distance to the nearest
    ally at t0:

    * ``d_min_t0 < d0``: reward moving away from the nearest ally,
      ``lamda_3 * (d_min_t1 - d_min_t0)``.
    * ``d0 <= d_min_t0 <= 1``: reward approaching the farthest ally returned
      by ``max_distance``, ``lamda_3 * (d_max_t0 - d_max_t1)``.
    * otherwise: 0 (a diagnostic line is printed).

    The last case previously reached ``return r3_tacit`` without the
    variable being assigned and raised ``UnboundLocalError``. The t1
    distances are now looked up only in the branch that uses them, so an
    ally id that is irrelevant to the chosen branch is never dereferenced.

    Args:
        lamda_3: Membership weight multiplying the distance change.
        agenti_st_t0: State vector of one agent before the step.
        agenti_st_t1: State vector of the same agent after the step.
        envs: Environment handle forwarded to the distance helpers.
        param_alpha: Threshold ``d0`` separating the two regimes.

    Returns:
        The weighted distance change, or 0 when ``d_min_t0`` is out of range.
    """
    d_min_t0, min_ally_id = min_distance(agenti_st_t0, envs)
    d_max_t0, max_ally_id = max_distance(agenti_st_t0, envs)
    d0 = param_alpha * 1

    if d_min_t0 < d0:
        d_min_t1 = dist_between(min_ally_id, agenti_st_t1, envs)
        return lamda_3 * (d_min_t1 - d_min_t0)

    if d_min_t0 >= d0 and d_min_t0 <= 1:
        d_max_t1 = dist_between(max_ally_id, agenti_st_t1, envs)
        return lamda_3 * (d_max_t0 - d_max_t1)

    # d_min_t0 is above 1 (or NaN): report it and give no reward.
    print("d_min_t0, d_max_t0: ", d_min_t0, d_max_t0)
    return 0

def tacit_4_reward(lamda_4, agenti_st_t0, agenti_st_t1, envs, param_alpha):
    """Return the pattern-3 reward for one agent.

    The agent is labelled leader (1) or follower (0) by
    ``distinguish_agents`` relative to its nearest ally; when
    ``allies_alive`` reports exactly one ally it is always a follower.

    * If the distance to the nearest ally at t1 is 0 the reward is 0.
    * Leader: ``(dl_t0 - dl_t1) * lamda_4`` using the ally returned by
      ``non_min_distance``.
    * Follower: ``(df_t1 - df_t0) * lamda_4`` when ``df_t0 < param_alpha``,
      ``(df_t0 - df_t1) * lamda_4`` when ``param_alpha <= df_t0 <= 1``,
      and 0 otherwise.

    The final follower case previously left ``r4_tacit_follow`` unassigned
    and raised ``UnboundLocalError``.

    NOTE(review): ``non_min_distance`` returns the id ``n_allies + 2`` when
    it finds no candidate; ``dist_between`` would then read an unrelated
    state entry. Confirm this cannot happen when more than one ally is alive.

    Args:
        lamda_4: Membership weight multiplying the distance change.
        agenti_st_t0: State vector of one agent before the step.
        agenti_st_t1: State vector of the same agent after the step.
        envs: Environment handle forwarded to the helpers.
        param_alpha: Threshold used by the follower reward.

    Returns:
        The reward for the agent's role.
    """
    df_t0, ally_agent_id = min_distance(agenti_st_t0, envs)
    agent_type = distinguish_agents(agenti_st_t0, envs, ally_agent_id)
    count_allies_alive = allies_alive(agenti_st_t0, envs)
    dl_t0, out_agent_id = non_min_distance(agenti_st_t0, envs)

    r4_tacit_lead = 0
    if count_allies_alive != 1:
        dl_t1 = dist_between(out_agent_id, agenti_st_t1, envs)
        r4_tacit_lead = (dl_t0 - dl_t1) * lamda_4
    else:
        # With a single ally there is nobody to lead towards.
        agent_type = 0

    df_t1 = dist_between(ally_agent_id, agenti_st_t1, envs)
    if df_t1 == 0:
        return 0

    if agent_type == 1:
        return r4_tacit_lead

    c_alpha = param_alpha
    if df_t0 < c_alpha * 1:
        return (df_t1 - df_t0) * lamda_4
    if df_t0 >= c_alpha * 1 and df_t0 <= 1:
        return (df_t0 - df_t1) * lamda_4
    # df_t0 is above 1 (or NaN): no follower reward is defined.
    return 0
# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------
def map_info(map_name, st_t1, obs_t1, envs):
    """Return the map extent and the ally shield-bit count for a map.

    Args:
        map_name: Key passed to ``get_map_params``.
        st_t1: Unused.
        obs_t1: Unused.
        envs: Unused.

    Returns:
        Tuple ``(map_x, map_y, shield_bits_ally)``. The extent is fixed at
        32 x 32; ``shield_bits_ally`` is 1 when the map's ``"a_race"`` entry
        equals ``"P"`` and 0 otherwise.
    """
    race = get_map_params(map_name)["a_race"]
    if race == "P":
        shield_bits_ally = 1
    else:
        shield_bits_ally = 0
    return 32, 32, shield_bits_ally

def agents_obs_visible(num_agents, envs, st_t0, step):
    """Code each agent by how many ally flags are set in its state.

    For agent ``i`` the entry at ``j * n_ally_feats`` of its state is read
    for each of the ``num_agents - 1`` ally slots; a slot counts when the
    entry equals 1. Resulting codes:

    * 0 -- no slot counted,
    * 1 -- exactly one slot counted,
    * 2 -- more than one slot counted, or the agent sits in a counted slot
      of an agent coded 2.

    Agents are processed in index order and an agent that already carries a
    non-zero code is not re-evaluated, so the result depends on that order.

    Args:
        num_agents: Number of agents.
        envs: Object exposing ``share_observation_space``; ``[0][1][1]`` is
            the width of one ally slot.
        st_t0: States indexed ``[step, agent, feature]``.
        step: Index on the first axis of ``st_t0``.

    Returns:
        Float array of length ``num_agents`` holding the codes.
    """
    slot_width = envs.share_observation_space[0][1][1]
    n_slots = num_agents - 1
    n_visible = np.zeros(num_agents)

    for agent in range(num_agents):
        state = st_t0[step, agent, :]
        flagged = [slot for slot in range(n_slots) if state[slot * slot_width] == 1]

        if n_visible[agent] != 0:
            # Already marked by another agent; keep that code.
            continue

        if len(flagged) == 1:
            n_visible[agent] = 1
        elif len(flagged) > 1:
            n_visible[agent] = 2
            for slot in flagged:
                # Ally slots skip the agent's own index, so slots at or past
                # it refer to the agent one position further on.
                target = slot if slot < agent else slot + 1
                n_visible[target] = 2

    return n_visible

def classify_pattern(num_agents, n_visible):
    """Translate visibility codes into pattern codes.

    When all codes sum to 0 every agent gets pattern 0. Otherwise code 0
    maps to pattern 1, code 1 to pattern 3 and code 2 to pattern 2; any
    other code leaves the pattern at 0.

    Args:
        num_agents: Number of agents.
        n_visible: Per-agent codes as produced by ``agents_obs_visible``.

    Returns:
        Float array of length ``num_agents`` with the pattern codes.
    """
    c_pattern = np.zeros(num_agents)
    if np.sum(n_visible) != 0:
        pattern_of = {0: 1, 1: 3, 2: 2}
        for idx in range(num_agents):
            c_pattern[idx] = pattern_of.get(n_visible[idx], 0)
    return c_pattern

def allies_alive(agenti_st_t0, envs):
    """Count the ally slots whose second entry is non-zero.

    Args:
        agenti_st_t0: State vector of one agent.
        envs: Object exposing ``share_observation_space``; ``[0][1]`` is
            read as ``(number of ally slots, width of one slot)``.

    Returns:
        Number of slots ``i`` with ``agenti_st_t0[i * width + 1] != 0``.
    """
    n_slots, width = envs.share_observation_space[0][1][0], envs.share_observation_space[0][1][1]
    return sum(1 for slot in range(n_slots) if agenti_st_t0[slot * width + 1] != 0)

def min_distance(agenti_st_t0, envs):
    """Find the ally slot with the smallest non-zero second entry.

    Args:
        agenti_st_t0: State vector of one agent.
        envs: Object exposing ``share_observation_space``; ``[0][1]`` is
            read as ``(number of ally slots, width of one slot)``.

    Returns:
        Tuple ``(distance, slot_index)``. When no slot has a non-zero entry
        the result is ``(inf, None)``. Ties keep the lowest slot index.
    """
    n_slots, width = envs.share_observation_space[0][1][0], envs.share_observation_space[0][1][1]
    nearest, nearest_id = np.inf, None
    for slot in range(n_slots):
        candidate = agenti_st_t0[slot * width + 1]
        # A zero entry is skipped rather than treated as distance 0.
        if candidate != 0 and candidate < nearest:
            nearest, nearest_id = candidate, slot
    return nearest, nearest_id

def dist_between(ally_agent_id, st_agent, envs):
    """Read the second entry of one ally slot from a state vector.

    Args:
        ally_agent_id: Index of the ally slot.
        st_agent: State vector of one agent.
        envs: Object exposing ``share_observation_space``; ``[0][1][1]`` is
            the width of one ally slot.

    Returns:
        ``st_agent[width * ally_agent_id + 1]``.
    """
    width = envs.share_observation_space[0][1][1]
    return st_agent[width * ally_agent_id + 1]

def max_distance(agenti_st_t0, envs):
    """Find the flagged ally slot with the largest second entry.

    Only slots whose first entry equals 1 are considered.

    Args:
        agenti_st_t0: State vector of one agent.
        envs: Object exposing ``share_observation_space``; ``[0][1]`` is
            read as ``(number of ally slots, width of one slot)``.

    Returns:
        Tuple ``(distance, slot_index)``. When no slot is flagged the result
        is ``(-inf, None)``. Ties keep the lowest slot index.
    """
    n_slots, width = envs.share_observation_space[0][1][0], envs.share_observation_space[0][1][1]
    farthest, farthest_id = -np.inf, None
    for slot in range(n_slots):
        base = slot * width
        if agenti_st_t0[base] != 1:
            continue
        candidate = agenti_st_t0[base + 1]
        if candidate > farthest:
            farthest, farthest_id = candidate, slot
    return farthest, farthest_id

def distinguish_agents(agenti_st_t0, envs, ally_agent_id):
    """Label the agent as follower (0) or leader (1) relative to one ally.

    The third and fourth entries of the ally's slot are compared with 0:
    the agent is a follower when the third entry is positive, or when it is
    exactly 0 and the fourth entry is positive; otherwise it is a leader.

    Args:
        agenti_st_t0: State vector of one agent.
        envs: Object exposing ``share_observation_space``; ``[0][1][1]`` is
            the width of one ally slot.
        ally_agent_id: Index of the ally slot to compare against.

    Returns:
        0 for a follower, 1 for a leader.
    """
    base = ally_agent_id * envs.share_observation_space[0][1][1]
    rel_x = agenti_st_t0[base + 2]
    rel_y = agenti_st_t0[base + 3]
    is_follower = rel_x > 0 or (rel_x == 0 and rel_y > 0)
    return 0 if is_follower else 1

def non_min_distance(agenti_st_t0, envs):
    """Find the unflagged ally slot with the smallest non-zero second entry.

    Only slots whose first entry equals 0 are considered, and a second entry
    of 0 is skipped.

    Args:
        agenti_st_t0: State vector of one agent.
        envs: Object exposing ``share_observation_space``; ``[0][1]`` is
            read as ``(number of ally slots, width of one slot)``.

    Returns:
        Tuple ``(distance, slot_index)``. When no slot qualifies the result
        is ``(0, n_slots + 2)``; that index is not a valid slot.
    """
    n_slots, width = envs.share_observation_space[0][1][0], envs.share_observation_space[0][1][1]
    nearest, nearest_id, found = np.inf, None, False
    for slot in range(n_slots):
        base = slot * width
        if agenti_st_t0[base] != 0:
            continue
        candidate = agenti_st_t0[base + 1]
        if candidate != 0 and candidate < nearest:
            nearest, nearest_id, found = candidate, slot, True
    if not found:
        # Nothing qualified: hand back the out-of-range marker.
        return 0, n_slots + 2
    return nearest, nearest_id
# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------

def calculate_tacit(pattern_i, tacit_reward_i, indicators_tac):
    """Fold one reward sample into the per-pattern statistics table.

    The table is updated in place, column ``int(pattern_i)``:

    * row 0 -- running mean of the rewards,
    * row 1 -- number of samples,
    * row 2 -- number of samples with reward ``>= 0``,
    * row 3 -- row 2 divided by row 1.

    Args:
        pattern_i: Pattern code selecting the column (cast to ``int``).
        tacit_reward_i: Reward sample to add.
        indicators_tac: Table indexed ``[row, pattern]`` with at least
            four rows.

    Returns:
        The same ``indicators_tac`` object after the update.
    """
    col = int(pattern_i)
    seen = indicators_tac[1, col]
    mean_so_far = indicators_tac[0, col]

    # Incremental mean: new_mean = (sample + n * old_mean) / (n + 1).
    indicators_tac[0, col] = (tacit_reward_i + seen * mean_so_far) / (seen + 1)
    indicators_tac[1, col] += 1

    if tacit_reward_i >= 0:
        indicators_tac[2, col] += 1

    indicators_tac[3, col] = indicators_tac[2, col] / indicators_tac[1, col]
    return indicators_tac

# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------



