#Reward functions
from numpy import nonzero
import torch

class RewardCalculator():
    """Per-arena reward shaping driven by the entity each agent is focusing.

    ``self.focus`` holds one integer id per arena; 0 means "no focus".
    Following the index layout this class was written against:

    ====  =============  ===============  ============
    id    entity         distance column  angle column
    ====  =============  ===============  ============
    1     own statue     2                3
    2     teammate1      4                5
    3     teammate2      6                7
    4     enemy statue   8                9
    5     enemy1         10               11
    6     enemy2         12               13
    7     enemy3         14               15
    ====  =============  ===============  ============

    i.e. for focus id ``f`` the distance is read from column ``2 * f`` and
    the angle from column ``2 * f + 1``.  Other columns read here are 16
    (has-focus flag), 20 (hp value used by the heal check) and 22 (used as
    the healing-ready flag).

    ``state`` / ``next_state`` are 2-D NumPy arrays with one row per arena.
    ``action`` is a per-arena sequence whose item 3 is the attack id and
    item 4 the requested focus id (0 = keep the current focus).
    """

    def __init__(self, n_arenas):
        """Start with no focus in any of the ``n_arenas`` arenas."""
        self.n_arenas = n_arenas
        self.focus = torch.zeros(n_arenas).long()

    def damage_enemy(self, state, action, next_state):
        """Reward hitting a focused enemy and penalise healing it.

        Returns a NumPy array with one value per arena: +1.0 when an enemy
        (focus id >= 4) is focused and a damaging attack lands, -5.0 when an
        enemy is focused and a heal lands, 0.0 otherwise; 0.01 times the
        smallest enemy distance (columns 8, 10, 12, 14 of ``next_state``) is
        then subtracted.
        """
        damage = self.__damage(state, action, next_state)
        heal = self.__heal(state, action, next_state)
        focus_enemy = self.focus >= 4
        reward = torch.zeros(state.shape[0])
        reward[focus_enemy & damage] = 1.0
        reward[focus_enemy & heal] = -5.0
        # Shaping term: being close to the nearest enemy costs less.
        min_enemy_distance = torch.from_numpy(next_state[:, [8, 10, 12, 14]]).min(1)[0]
        return reward.numpy() - 0.01 * min_enemy_distance.numpy()

    def heal_teammate(self, state, action, next_state):
        """Reward healing a focused friendly and penalise damaging it.

        Returns a float32 NumPy array with one value per arena: +1.0 when a
        friendly (focus id 1..3) is focused and a heal lands, -5.0 when a
        friendly is focused and a damaging attack lands, 0.0 otherwise.
        """
        heal = self.__heal(state, action, next_state)
        damage = self.__damage(state, action, next_state)
        focus_friend = (self.focus >= 1) & (self.focus <= 3)
        reward = torch.zeros(state.shape[0])
        reward[focus_friend & heal] = 1.0
        reward[focus_friend & damage] = -5.0
        return reward.numpy()

    def kill_enemy(self, state, action, next_state):
        """Placeholder: no kill reward is implemented; returns ``None``."""
        pass

    def update_focus(self, state, action, next_state):
        """Update the tracked focus id of every arena after one step.

        A non-zero ``action[i][4]`` replaces the stored focus of arena ``i``.
        The focus is then cleared wherever column 16 went from 1.0 in
        ``state`` to 0.0 in ``next_state``, and finally scaled by column 16
        of ``next_state`` (so a zero flag always clears the focus).
        """
        requested = torch.tensor([a[4] for a in action])
        changed = requested != 0
        # Cast explicitly: masked assignment needs matching dtypes, and the
        # requested ids may arrive as floats.
        self.focus[changed] = requested[changed].to(self.focus.dtype)

        # Work on tensors explicitly instead of relying on tensor-times-ndarray
        # operator fallbacks.
        had_focus = torch.as_tensor(state[:, 16])
        has_focus = torch.as_tensor(next_state[:, 16])
        focus_lost = (has_focus == 0.0) & (had_focus == 1.0)
        self.focus[focus_lost] = 0  # zero if focus was lost
        # Multiply in float64 and truncate back to integer ids.
        self.focus = (self.focus.double() * has_focus.double()).long()

    def __heal(self, state, action, next_state):
        """Boolean tensor per arena: will a heal be inflicted on the focus?

        Requires attack id 3, a truthy column 22 in ``state``, a focus
        distance strictly between 0.06 and 0.5, an absolute focus angle below
        0.2 and column 20 of ``next_state`` below 1.0.
        """
        focus_distance = self.__get_focus_distance(next_state)
        focus_angle = self.__get_focus_angle(next_state)
        attacks = self.__get_attacks(action)
        healing_ready = torch.as_tensor(state[:, 22]).bool()
        distance_threshold = 0.5
        angle_threshold = 0.2
        focus_low_hp = torch.as_tensor(self.__get_focus_hp(next_state)) < 1.0
        return (
            (attacks == 3)
            & healing_ready
            & (focus_distance > 0.06)
            & (focus_distance < distance_threshold)
            & (focus_angle.abs() < angle_threshold)
            & focus_low_hp
        )

    def __damage(self, state, action, next_state):
        """Boolean tensor per arena: will damage be inflicted on the focus?

        Requires attack id 1, a focus distance below 0.072 and an absolute
        focus angle below 0.24.
        """
        attacks = self.__get_attacks(action)
        focus_distance = self.__get_focus_distance(next_state)
        focus_angle = self.__get_focus_angle(next_state)
        distance_threshold = 0.072
        angle_threshold = 0.24
        return (
            (attacks == 1)
            & (focus_distance < distance_threshold)
            & (focus_angle.abs() < angle_threshold)
        )

    def __focused_column(self, state, columns):
        """Read, per arena, the observation column named by ``columns``.

        Arenas without a focus keep the default 1.0, which lies outside every
        distance/angle threshold used by the heal and damage checks.
        """
        observations = torch.as_tensor(state)
        has_focus = self.focus != 0
        values = torch.ones(observations.shape[0])
        picked = torch.gather(
            observations[has_focus], 1, columns[has_focus].unsqueeze(1)
        ).squeeze(1)
        # Cast so float64 observations can be written into the float32 result.
        values[has_focus] = picked.to(values.dtype)
        return values

    def __get_focus_distance(self, state):
        """Distance to the focused entity per arena (column ``2 * focus``)."""
        return self.__focused_column(state, self.focus * 2)

    def __get_focus_angle(self, state):
        """Angle to the focused entity per arena (column ``2 * focus + 1``)."""
        return self.__focused_column(state, self.focus * 2 + 1)

    def __get_focus_hp(self, state):
        """Return column 20 of ``state`` for every arena."""
        return state[:, 20]

    def __get_attacks(self, action):
        """Return a tensor with item 3 (the attack id) of every action."""
        return torch.tensor([a[3] for a in action])


# mdenv rewards
import numpy as np
def within_radius(point, center, radius):
    """Return 1.0 where ``point`` lies strictly inside the ball, else 0.0.

    The squared Euclidean distance is reduced over the last axis only, so a
    batch of points shaped ``(N, D)`` yields ``N`` independent indicators
    instead of one value for the whole batch.  A single 1-D point (or plain
    scalars) still yields a single ``np.float32`` value.

    Args:
        point: coordinates, shape ``(D,)`` or ``(..., D)``.
        center: ball centre, broadcastable against ``point``.
        radius: ball radius; the comparison is against ``radius ** 2``.

    Returns:
        ``np.float32`` scalar or array of 0.0 / 1.0 indicators.
    """
    offset = np.asarray(point) - np.asarray(center)
    squared = offset ** 2
    # Scalars have no axis to reduce over; everything else sums per point.
    squared_distance = np.sum(squared, axis=-1) if np.ndim(squared) else squared
    return (squared_distance < radius ** 2).astype(np.float32)

def neg_distance(point, another, radius):
    """Return the negated squared Euclidean distance between paired rows.

    ``point`` and ``another`` are 2-D arrays (rows are points); the norm is
    taken along axis 1, squared, and negated.  ``radius`` is accepted for
    signature compatibility but is not used.
    """
    offset = point - another
    row_norms = np.linalg.norm(offset, axis=1)
    return np.negative(row_norms ** 2)
