import torch
import numpy as np
import math
from code_ptmc_mappo.envs.starcraft2.smac_maps import get_map_params
import copy

def data_mask(obs, av_action, args):
    """Mask action index 5 and observation channels 1 and 2.

    Args:
        obs: Observation array; it is viewed as cells of 3 channels laid out
            on a ``(2 * rx + 1, 2 * ry + 1)`` grid, where ``(rx, ry)`` is
            ``args.agent_obs``.
        av_action: Action-availability array indexed as
            ``[thread, agent, action]``. Modified in place.
        args: Object providing ``agent_obs``, ``n_rollout_threads`` and
            ``n_agents``.

    Returns:
        Tuple ``(new_obs, av_action)`` where ``new_obs`` has shape
        ``(n_rollout_threads, n_agents, -1)``.
    """
    window = args.agent_obs
    thread_count = args.n_rollout_threads
    agent_count = args.n_agents

    # Action index 5 is always flagged as unavailable.
    av_action[:, :, 5] = 0

    half_x, half_y = window
    grid = obs.reshape(-1, 2 * half_x + 1, 2 * half_y + 1, 3)
    # NOTE(review): reshape may return a view, in which case the caller's
    # `obs` is zeroed as well -- confirm this is intended.
    grid[:, :, :, 1:3] = 0  # keep channel 0 only
    return grid.reshape(thread_count, agent_count, -1), av_action

def reward_reconfiguration(args, agent_pos_t0, agent_pos, obs, tacit_indicator):
    """Compute the per-agent tacit reward for every rollout thread.

    For each thread the agents are classified with ``classify_pattern`` and
    the matching membership / reward pair is evaluated per agent. Each reward
    is also folded into ``tacit_indicator`` through ``calculate_tacit``.

    Args:
        args: Object providing ``n_rollout_threads``, ``n_agents`` and
            ``agent_obs``.
        agent_pos_t0: Positions before the step, indexed ``[thread, agent, :]``.
        agent_pos: Positions after the step, indexed ``[thread, agent, :]``.
        obs: Observations indexed ``[thread, agent, :]``.
        tacit_indicator: Running statistics table passed to and returned by
            ``calculate_tacit``.

    Returns:
        Tuple ``(tacit_reward, tacit_indicator)``; ``tacit_reward`` is a
        float64 array of shape ``(n_rollout_threads, n_agents, 1)``.
    """
    tacit_reward = np.zeros((args.n_rollout_threads, args.n_agents, 1), dtype=np.float64)
    sight_range = args.agent_obs[0]

    for env_idx in range(args.n_rollout_threads):
        c_pattern = classify_pattern(ag_obs_reshape(obs[env_idx, :, :], args), args)

        for i in range(args.n_agents):
            rs_obs_i = obs_reshape(obs[env_idx, i, :], args)
            prev_pos = agent_pos_t0[env_idx, :, :]
            next_pos = agent_pos[env_idx, :, :]
            pattern = c_pattern[i]

            if pattern in (0, 1):
                # Patterns 0 and 1 share the same membership and reward.
                membership = pattern_1or2_membership(i, prev_pos, sight_range)
                reward = tacit_2_reward(membership, i, prev_pos, next_pos)
            elif pattern == 2:
                membership = pattern_3_membership(i, prev_pos, sight_range)
                reward = tacit_3_reward(membership, i, prev_pos, next_pos)
            elif pattern == 3:
                membership = pattern_4_membership(i, prev_pos, rs_obs_i, sight_range)
                reward = tacit_4_reward(membership, i, prev_pos, next_pos, rs_obs_i, sight_range)
            else:
                # Unknown label: report it and skip the remaining agents of
                # this thread.
                print("c_pattern error: ", c_pattern, c_pattern[i])
                break

            tacit_reward[env_idx, i, :] = reward
            tacit_indicator = calculate_tacit(pattern, tacit_reward[env_idx, i, :], tacit_indicator)

    return tacit_reward, tacit_indicator

# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------
def pattern_1or2_membership(agents_id, agents_pos_0, sight_range):
    """Return the membership value used for patterns 0 and 1.

    The value is based on the distance from agent ``agents_id`` to its
    nearest other agent (via ``min_distance``): it is ``1`` once that
    distance reaches ``sqrt((s + 1)**2 + (s + 2)**2)`` and otherwise a
    linear ramp measured from ``s + 1``, where ``s`` is ``sight_range``.
    """
    nearest_gap, _ = min_distance(agents_id, agents_pos_0)
    ramp_start = sight_range + 1
    ramp_end = math.sqrt(ramp_start ** 2 + (sight_range + 2) ** 2)

    if nearest_gap >= ramp_end:
        return 1
    return (nearest_gap - ramp_start) / (ramp_end - ramp_start)

def pattern_3_membership(agents_id, agents_pos_0, sight_range):
    """Return the membership value used for pattern 2.

    With ``d`` the distance to the nearest other agent (via
    ``min_distance``) and ``c_alpha = 1.0``:

    * ``d <= 1``                      -> ``0``
    * ``1 < d <= c_alpha * sight``    -> ``1 - (d - 1) / (c_alpha * sight - 1)``
    * otherwise                       -> ``1``
    """
    nearest_gap, _ = min_distance(agents_id, agents_pos_0)
    c_alpha = 1.0
    ramp_end = c_alpha * sight_range

    if nearest_gap <= 1:
        return 0
    if nearest_gap <= ramp_end:
        return 1 - (nearest_gap - 1) / (ramp_end - 1)
    return 1

def pattern_4_membership(agents_id, agents_pos_0, obs_i, sight_range):
    """Return the pattern-3 (leader/follower pair) membership of one agent.

    The agent's own cell is the centre ``(sight_range, sight_range)`` of
    ``obs_i``; the first other cell equal to ``1`` is taken as its partner.
    The agent with the smaller x (ties broken by smaller y) is the leader.

    Only the value that is actually returned is computed. In particular a
    follower no longer triggers ``outobs_d_min`` on its partner, so a
    follower value is available even when the leader-side computation would
    fail (e.g. fewer than three agents).

    Args:
        agents_id: Index of the agent in ``agents_pos_0``.
        agents_pos_0: Array of agent positions, one ``(x, y)`` row per agent.
        obs_i: 2-D observation grid of the agent (cells equal to 1 are
            occupied).
        sight_range: Half-width of the observation grid.

    Returns:
        Leader: ``1`` if the second-smallest positive distance to other
        agents (``outobs_d_min``) is at least
        ``sqrt((s + 1)**2 + (s + 2)**2)``, else a linear ramp from ``s + 1``.
        Follower: ``0`` if the grid distance ``d`` to the partner is ``<= 1``,
        ``1 - (d - 1) / (sight_range - 1)`` if ``1 < d <= sight_range``,
        else ``1``.

    Raises:
        IndexError: If ``obs_i`` contains no occupied cell besides the centre.
    """
    centre = [sight_range, sight_range]
    occupied = np.argwhere(obs_i == 1)
    partner_cell = occupied[~np.all(occupied == centre, axis=1)][0]

    # Partner position in the same frame as agents_pos_0.
    x1, y1 = agents_pos_0[agents_id]
    x2, y2 = agents_pos_0[agents_id] + partner_cell - centre
    is_leader = x1 < x2 if x1 != x2 else y1 < y2

    if is_leader:
        leader_dmin = outobs_d_min(agents_id, agents_pos_0)
        ramp_start = sight_range + 1
        c_beta = math.sqrt(ramp_start ** 2 + (sight_range + 2) ** 2)
        if leader_dmin >= c_beta:
            return 1
        return (leader_dmin - ramp_start) / (c_beta - ramp_start)

    # Follower: membership depends only on the in-grid distance to the partner.
    d_between = math.sqrt((partner_cell[0] - sight_range) ** 2 + (partner_cell[1] - sight_range) ** 2)
    c_alpha = 1.0
    if d_between <= 1:
        return 0
    if d_between <= c_alpha * sight_range:
        return 1 - (d_between - 1) / (c_alpha * sight_range - 1)
    return 1

def tacit_2_reward(lamda_2, agents_id, agents_pos_0, agents_pos_1):
    """Return the membership-weighted change in distance to the nearest agent.

    The nearest other agent is chosen from ``agents_pos_0``. The new distance
    is measured after moving only agent ``agents_id`` to its position in
    ``agents_pos_1``; every other agent stays at its ``agents_pos_0``
    position. The result is ``(old - new) * lamda_2``.
    """
    dist_before, closest_id = min_distance(agents_id, agents_pos_0)

    # Copy so that the caller's position table is left untouched.
    moved = copy.deepcopy(agents_pos_0)
    moved[agents_id] = agents_pos_1[agents_id]
    dist_after = dist_between(agents_id, closest_id, moved)

    return (dist_before - dist_after) * lamda_2

def tacit_3_reward(lamda_3, agents_id, agents_pos_0, agents_pos_1):
    """Return ``(d0 - d1) * lamda_3`` for agent ``agents_id``.

    ``d0`` is the distance to the nearest other agent in ``agents_pos_0``;
    ``d1`` is the distance to that same agent once only ``agents_id`` has
    been moved to its ``agents_pos_1`` position.
    """
    before, closest_id = min_distance(agents_id, agents_pos_0)

    hypothetical = copy.deepcopy(agents_pos_0)  # do not mutate the input
    hypothetical[agents_id] = agents_pos_1[agents_id]

    approach = before - dist_between(agents_id, closest_id, hypothetical)
    return approach * lamda_3

def tacit_4_reward(lamda_4, agents_id, agents_pos_0, agents_pos_1, obs_i,sight_range):
    """Return the tacit reward of one agent of a leader/follower pair.

    The agent's own cell is the centre ``(sight_range, sight_range)`` of
    ``obs_i``; the first other cell equal to ``1`` is its partner. The agent
    with the smaller x (ties broken by smaller y) is the leader.

    * Leader: ``(d0 - d1) * lamda_4`` where ``d0`` / ``d1`` are
      ``outobs_d_min`` before and after moving only this agent to its
      ``agents_pos_1`` position.
    * Follower: ``(d0 - d1) * lamda_4`` where ``d0`` / ``d1`` are the
      Euclidean distances from the follower's old / new position to the
      leader's old position.

    The follower distance uses the full ``(x, y)`` vectors. The previous code
    indexed ``[0]`` on a single position row, so it compared x coordinates
    only and ignored any movement along y.

    Args:
        lamda_4: Membership weight applied to the distance change.
        agents_id: Index of the agent in the position arrays.
        agents_pos_0: Positions before the step, one ``(x, y)`` row per agent.
        agents_pos_1: Positions after the step, same layout.
        obs_i: 2-D observation grid of the agent.
        sight_range: Half-width of the observation grid.

    Raises:
        IndexError: If ``obs_i`` has no occupied cell besides the centre, or
            if the agent is a follower and no agent in ``agents_pos_0`` sits
            at the partner position.
    """
    centre = [sight_range, sight_range]
    occupied = np.argwhere(obs_i == 1)
    partner_cell = occupied[~np.all(occupied == centre, axis=1)][0]

    # Partner position in the same frame as agents_pos_0.
    x1, y1 = agents_pos_0[agents_id]
    x2, y2 = agents_pos_0[agents_id] + partner_cell - centre
    is_leader = x1 < x2 if x1 != x2 else y1 < y2

    if is_leader:
        # leader-agent
        leader_dmin_0 = outobs_d_min(agents_id, agents_pos_0)
        moved = copy.deepcopy(agents_pos_0)  # leave the caller's array intact
        moved[agents_id] = agents_pos_1[agents_id]
        leader_dmin_1 = outobs_d_min(agents_id, moved)
        return (leader_dmin_0 - leader_dmin_1) * lamda_4

    # follower-agent: the partner seen in the grid is the leader.
    matches = np.where((agents_pos_0[:, 0] == x2) & (agents_pos_0[:, 1] == y2))[0]
    if matches.size == 0:
        raise IndexError(
            "tacit_4_reward: no agent found at the partner position seen in obs_i"
        )
    leader_pos = agents_pos_0[matches[0]]
    d_between_0 = np.linalg.norm(agents_pos_0[agents_id] - leader_pos)
    d_between_1 = np.linalg.norm(agents_pos_1[agents_id] - leader_pos)
    return (d_between_0 - d_between_1) * lamda_4

# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------
def ag_obs_reshape(agents_obs, args):
    """Return channel 0 of every agent's observation grid.

    ``agents_obs`` is reshaped to ``(n_agents, 2 * rx + 1, 2 * ry + 1, 3)``
    with ``(rx, ry) = args.agent_obs``; the result keeps a trailing channel
    axis of length 1.
    """
    half_x, half_y = args.agent_obs
    grid_w = 2 * half_x + 1
    grid_h = 2 * half_y + 1
    per_agent = agents_obs.reshape(args.n_agents, grid_w, grid_h, 3)
    return per_agent[:, :, :, :1]

def obs_reshape(obs, args):
    """Return channel 0 of a single agent's observation as a 2-D grid.

    ``obs`` is reshaped to ``(2 * rx + 1, 2 * ry + 1, 3)`` with
    ``(rx, ry) = args.agent_obs`` and the first channel is returned.
    """
    half_x, half_y = args.agent_obs
    grid = obs.reshape(2 * half_x + 1, 2 * half_y + 1, 3)
    return grid[:, :, 0]

def classify_pattern(agents_obs, args):
    """Assign an integer pattern label to every agent.

    The label depends on the sum of the agent's observation grid:
    a sum of 1 gives ``0``, a sum of 2 gives ``3``, anything else gives ``2``.
    If at least one agent ends up with a non-zero label, every ``0`` label is
    replaced by ``1``.

    Returns:
        Integer array of length ``args.n_agents``.
    """
    labels = np.zeros(args.n_agents)
    for agent in range(args.n_agents):
        ones = np.sum(agents_obs[agent, :, :, :])
        labels[agent] = 0 if ones == 1 else (3 if ones == 2 else 2)

    # Label 0 is only kept when it applies to every agent.
    if not np.all(labels == 0):
        labels[labels == 0] = 1
    return labels.astype(int)

def min_distance(agents_id, agents_position):
    """Return the distance to, and index of, the nearest other agent.

    Args:
        agents_id: Row index of the reference agent.
        agents_position: 2-D array with one position per row.

    Returns:
        Tuple ``(d_min, closest_id)``.
    """
    offsets = agents_position - agents_position[agents_id]
    gaps = np.linalg.norm(offsets, axis=1)
    gaps[agents_id] = np.inf  # never pick the agent itself
    nearest = np.argmin(gaps)
    return gaps[nearest], nearest

def dist_between(agent_id, closest_id, agents_pos):
    """Return the Euclidean norm of the offset between two rows of ``agents_pos``."""
    offset = agents_pos[agent_id] - agents_pos[closest_id]
    return np.linalg.norm(offset)

def outobs_d_min(agents_id, agents_position):
    """Return the second-smallest strictly positive distance to other agents.

    Distances equal to zero (the agent itself and any agent at the same
    position) are discarded before ranking.

    Raises:
        IndexError: If fewer than two positive distances exist.
    """
    offsets = agents_position - agents_position[agents_id]
    gaps = np.linalg.norm(offsets, axis=1)
    positive_gaps = gaps[gaps > 0]
    # Rank 0 is the nearest agent; rank 1 is the one after it.
    return np.sort(positive_gaps)[1]

# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------

def calculate_tacit(pattern_i, tacit_reward_i, indicators_tac):
    """Fold one reward into the running statistics of its pattern column.

    Rows of ``indicators_tac`` for column ``pattern_i``:
        0: running mean of the rewards,
        1: number of rewards seen,
        2: number of rewards that were ``>= 0``,
        3: row 2 divided by row 1.

    The table is updated in place and also returned.
    """
    col = int(pattern_i)
    seen = indicators_tac[1, col]

    # Incremental mean, using the count from before this sample.
    weighted_total = tacit_reward_i + seen * indicators_tac[0, col]
    indicators_tac[0, col] = weighted_total / (seen + 1)

    indicators_tac[1, col] += 1
    if tacit_reward_i >= 0:
        indicators_tac[2, col] += 1
    indicators_tac[3, col] = indicators_tac[2, col] / indicators_tac[1, col]
    return indicators_tac

# ---------------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------------



