#utils
import gym
from gym import Env
from gym.spaces import MultiDiscrete
import torch
import numpy as np
import asyncio
import scipy.spatial.distance

# Configs
# Per-unit team configurations (three units per team). set_reward_function
# copies the home_team entries, attaches a "rewardFunction" to each copy and
# sends both lists to env.app.
# NOTE(review): the meaning of the individual "slots" entries is defined by
# the simulator, not by this file -- confirm against the environment docs.
home_team = [
    {"primaryColor": color, "slots": ["Talons", None, "HealingGland"]}
    for color in ("#FF0000", "#00FFFF", "#00FFFF")
]

away_team = [
    {"primaryColor": "#FF00FF", "slots": ["Talons", None, "HealingGland"]}
    for _ in range(3)
]


class SingleAgentWrapper(gym.Wrapper):
    """Expose one agent per arena to the learner; a bot policy drives the rest.

    After every environment transition (and after ``reset``) the bot policy is
    asked for a complete action list. On the next ``step`` the caller's
    actions overwrite the entries selected by ``homes_first_agent_slice`` and
    the merged list is sent to the wrapped environment.
    """

    def __init__(self, env: Env, bot_policy) -> None:
        self.env = env
        agents_per_team = self.env.n_agents_per_team
        self.n_arenas = self.env.n_teams // 2
        # Layout note kept from the original author:
        # shape is [n_agents_per_team * n_arena * n_teams(home_first), agent, features]
        # The slice picks every ``agents_per_team``-th entry among the first
        # ``n_arenas * agents_per_team`` ones.
        self.homes_first_agent_slice = slice(
            0, self.n_arenas * agents_per_team, agents_per_team
        )
        self.bot_policy = bot_policy
        # Placeholder actions until the bot policy produces real ones.
        self.act_list = [(2, 2, 0, 0, 0)] * self.env.n_agents

        super().__init__(env)

    def step(self, action):
        """Merge ``action`` into the pending bot actions and advance the env.

        Returns the observation, reward and done entries restricted to
        ``homes_first_agent_slice``, plus the unmodified ``info``.
        """
        controlled = self.homes_first_agent_slice

        self.act_list[controlled] = action
        obs, reward, done, info = self.env.step(self.act_list)

        # Pre-compute what the bots will do on the following step.
        self.generate_next_bot_actions(obs)

        return obs[controlled], reward[controlled], done[controlled], info

    def reset(self):
        """Reset the wrapped env and return the controlled agents' observations."""
        first_obs = self.env.reset()
        self.generate_next_bot_actions(first_obs)
        return first_obs[self.homes_first_agent_slice]

    def generate_next_bot_actions(self, state):
        """Store the bot policy's actions for every agent.

        The controlled agents' entries are computed as well, but ``step``
        replaces them with the caller's actions.
        """
        self.act_list = self.bot_policy.get_action(state)


class DiscreteActionEnv(gym.Wrapper):
    """Present a ``MultiDiscrete`` action space and translate choices back.

    ``discretization`` holds one entry per action dimension: either ``None``
    (the dimension is passed through using the wrapped space's own ``n``) or
    an array whose rows are intervals; a discrete choice ``i`` is then
    translated to the midpoint of row ``i``.
    """

    def __init__(self, env: Env, discretization) -> None:
        self.env = env
        super().__init__(env)

        n_actions = []
        for dim, bins in enumerate(discretization):
            if bins is None:
                n_actions.append(self.env.action_space[dim].n)
            else:
                n_actions.append(bins.shape[0])

        self.action_space = MultiDiscrete(n_actions)
        self.action_mapping = self.build_action_map(discretization)

    def step(self, action):
        """Translate every agent's discrete action tuple and step the env."""
        translated = [
            tuple(
                self.action_mapping[dim][choice]
                for dim, choice in enumerate(agent_action)
            )
            for agent_action in action
        ]
        obs, rew, done, info = self.env.step(translated)
        return obs, rew, done, info

    def build_action_map(self, discretization):
        """
        Build correspondence discrete => continuous action

        Returns a list with one dict per action dimension, mapping a discrete
        index to the value handed to the wrapped environment.
        """
        lookup = []
        for dim, bins in enumerate(discretization):
            if bins is None:
                # Pass-through dimension: identity mapping.
                n_choices = self.env.action_space[dim].n
                lookup.append({choice: choice for choice in range(n_choices)})
            else:
                # Discretized dimension: each choice maps to its interval midpoint.
                lookup.append(
                    {
                        choice: (bounds[0].item() + bounds[1].item()) / 2.0
                        for choice, bounds in enumerate(bins)
                    }
                )

        return lookup
                

class BotPolicy():
    """Action source for non-learning agents.

    ``get_action`` is bound at construction time: to ``heuristic`` when no
    network is supplied, otherwise to ``net_action``.
    """

    def __init__(self, env, bot_net=None, weights= None):
        """
        Args:
            env: environment whose ``action_space.nvec`` the heuristic reads.
            bot_net: optional callable; called with a state tensor and expected
                to return an iterable of distributions exposing ``sample()``.
            weights: optional numpy array appended to every state row before
                it is fed to ``bot_net``. May be ``None`` for the heuristic.
        """
        self.env = env
        self.dist_threshold = 0.05
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.set_weights(weights)
        if bot_net is None:
            self.get_action = self.heuristic
        else:
            self.bot_net = bot_net
            self.get_action = self.net_action

    def set_weights(self, weights):
        """Store ``weights`` as a float tensor on ``self.device``.

        ``None`` is accepted (and stored as ``None``) so that the documented
        default ``weights=None`` of ``__init__`` does not crash in
        ``torch.from_numpy``.
        """
        if weights is None:
            self.weights = None
            return
        self.weights = torch.from_numpy(weights).float().to(self.device)

    def net_action(self, state):
        """Sample one action tuple per state row from ``bot_net``.

        Raises:
            ValueError: if no weights have been set.
        """
        if self.weights is None:
            raise ValueError("BotPolicy.net_action requires weights; call set_weights() first")

        with torch.no_grad():
            state = torch.from_numpy(state).to(self.device)
            # Every row gets the same weight vector appended.
            state = torch.cat((state, self.weights.expand(state.shape[0], -1)), dim=1)
            dists = self.bot_net(state)

            actions = torch.stack([dist.sample() for dist in dists], dim=1)

        return [tuple(a.tolist()) for a in actions]

    def heuristic(self, state):
        """Rule-based action per state row.

        NOTE(review): columns 8/10/12/14 are treated as distances and columns
        32-35 / 40-41 as melee / ranged equipment flags; the ``+ 4`` offset on
        the target index is taken as-is from the original code -- confirm
        against the observation layout.
        """
        dists = state[:, (8, 10, 12, 14)]
        argmindist = np.argmin(dists, axis=-1)
        mindist = np.take_along_axis(dists, argmindist.reshape((-1, 1)), 1)
        enemy_ind = np.int32(argmindist + 4)
        has_melee_arms = np.any(state[:, (32, 33, 34, 35)] == 1.0, axis=1)
        has_range_arms = np.any(state[:, (40, 41)] == 1.0, axis=1)

        # Loop-invariant action components.
        nvec = self.env.action_space.nvec
        mid_0 = nvec[0] // 2
        mid_1 = nvec[1] // 2
        max_2 = nvec[2] - 1

        actions = []
        for ii in range(state.shape[0]):
            target = enemy_ind[ii].item()
            if mindist[ii] > self.dist_threshold and not has_range_arms[ii]:
                #move towards enemy
                actions.append((mid_0, mid_1, max_2, 0, target))
            elif has_range_arms[ii]:
                #shoot enemy
                actions.append((mid_0, mid_1, 0, 1, target))
            elif has_melee_arms[ii]:
                #attack closest enemy
                actions.append((0, 0, max_2, 1, target))
            else:
                #stand still
                # NOTE(review): floats kept from the original although the
                # other branches emit integer indices.
                actions.append((0., 0., 0., 0, 0))

        return actions


def set_reward_function(weights, env):
    """Push per-agent reward functions derived from ``weights`` to ``env.app``.

    Every row of ``weights`` is repeated three times; each repeated row is
    paired with a shallow copy of a ``home_team`` entry (cycling through the
    first three) and given a ``"rewardFunction"`` dict. The resulting home
    configs and the unmodified ``away_team`` are then sent through
    ``env.app.update_home_team_config`` / ``update_away_team_config`` on the
    current asyncio event loop.
    """
    per_agent_weights = np.repeat(weights, [3] * weights.shape[0], axis=0)

    home_configs = []
    for agent_idx, w in enumerate(per_agent_weights):
        # Shallow copy: only a new top-level key is added below.
        agent_cfg = home_team[agent_idx % 3].copy()
        agent_cfg["rewardFunction"] = {
            "killEnemyStatue": 0,
            "killEnemyUnit": 0,
            "damageEnemyUnit": w[0].item() / 1.0 * 1,
            "healTeammate1": w[1].item() / 2.0 * 1,
            "healTeammate2": w[1].item() / 2.0 * 1,
            "friendlyFire": -2,
            "statueDamageTaken": -0.1,
            "fallDamageTaken": -60,
            "healEnemy": -2,
            "timeScaling": 0
        }
        home_configs.append(agent_cfg)

    asyncio.get_event_loop().run_until_complete(env.app.update_home_team_config(home_configs))
    asyncio.get_event_loop().run_until_complete(env.app.update_away_team_config(away_team))


def cari_concat(state, weights):
    """Return ``state`` with ``weights`` appended along axis 1."""
    return np.concatenate([state, weights], axis=1)

def get_goals(n_goals, random_locs=False, random_rad=False):
    """Sample 2-D goal centres (via polar coordinates) and goal radii.

    Args:
        n_goals: number of goals.
        random_locs: if False, goals are evenly spaced on a circle of radius
            0.5; otherwise angle is uniform in [0, 2*pi) and distance from the
            origin is uniform in [0.2, 0.8).
        random_rad: if False every goal radius is 0.1; otherwise radii are
            uniform in [0.1, 0.5).

    Returns:
        ``(centres, radii)`` with shapes ``(n_goals, 2)`` and ``(n_goals, 1)``.
    """
    if random_locs:
        angles = np.random.uniform(0, 2*np.pi, n_goals)
        centre_dists = np.random.uniform(0.2, 0.8, n_goals)
    else:
        # Drop the duplicated end point (2*pi == 0).
        angles = np.linspace(0, 2*np.pi, n_goals+1)[:-1]
        centre_dists = np.full(n_goals, 0.5)

    if random_rad:
        goal_rads = np.random.uniform(0.1, 0.5, n_goals)
    else:
        goal_rads = np.full(n_goals, 0.1)

    centres = np.column_stack(
        (centre_dists * np.cos(angles), centre_dists * np.sin(angles))
    )
    return centres, goal_rads.reshape((-1, 1))


def get_hypersphere_locations(n_dim, n_goals, n_trials=1000):
    """Pick well-separated goal centres on a hypersphere of radius 0.5.

    ``n_trials`` candidate sets of ``n_goals`` points are drawn uniformly on
    the unit sphere in ``n_dim`` dimensions (normalised Gaussian samples, see
    http://extremelearning.com.au/how-to-generate-uniformly-random-points-on-n-spheres-and-n-balls/).
    The candidate set whose smallest pairwise Euclidean distance is largest
    is kept and scaled by 0.5.

    The draw uses a private ``RandomState(123)``: results are reproducible
    across scripts and identical to seeding the global generator with 123,
    but the global NumPy RNG state is never touched -- not even when an
    exception is raised part-way through.

    Args:
        n_dim: dimensionality of the space.
        n_goals: number of goals per candidate set. With fewer than two goals
            there is no pairwise distance, so the first candidate set is used.
        n_trials: number of candidate sets to draw.

    Returns:
        ``(goals, rads)`` with shapes ``(n_goals, n_dim)`` and ``(n_goals,)``;
        every radius is 0.1.
    """
    rng = np.random.RandomState(123)

    u = rng.normal(0, 1, (n_trials, n_dim, n_goals))
    u = u / np.linalg.norm(u, axis=1, keepdims=True)

    # Score of a candidate set = its minimum pairwise distance. The large
    # default keeps argmax well-defined when no distance can be computed.
    mm_dists = np.full(n_trials, 1e10)
    if n_goals >= 2:
        for ii, candidate in enumerate(u):
            p_dist = scipy.spatial.distance.pdist(candidate.T, metric="euclidean")
            mm_dists[ii] = np.min(p_dist)

    goals = 0.5 * u[np.argmax(mm_dists)].T
    rads = 0.1*np.ones(n_goals)
    return goals, rads


def flatten_buffer(buffer, to_end = 0):
    """
     Flatten the per-timestep entries of a buffer in place.

     Each of buffer.states / actions / logprobs / rewards / is_terminals /
     next_states may hold N_ts batched entries (one per timestep, first axis
     indexing the arena). They are replaced by N_ts * N_arenas single-row
     entries, ordered arena by arena so that each arena's transitions stay
     contiguous and in time order.

     to_end is the index from which entries are still unflattened; everything
     before it is left untouched (use it when unflattened entries were
     appended to an already flattened buffer).
    """
    field_names = (
        "states",
        "actions",
        "logprobs",
        "rewards",
        "is_terminals",
        "next_states",
    )

    def flat_ll(col):
        # One list of rows per timestep ...
        rows_per_step = [list(entry) for entry in col]
        # ... transposed into one tuple of rows per arena.
        rows_per_arena = list(zip(*rows_per_step))

        flattened = []
        for arena_rows in rows_per_arena:
            for row in arena_rows:
                if isinstance(row, torch.Tensor):
                    flattened.append(row.unsqueeze(0))
                else:
                    flattened.append(row.reshape((1)))
        return flattened

    # 1) take a snapshot of the entries that still need flattening
    pending = {name: getattr(buffer, name)[to_end:] for name in field_names}

    # 2) remove them from the buffer
    for name in field_names:
        del getattr(buffer, name)[to_end:]

    # 3) append their flattened form
    for name in field_names:
        column = getattr(buffer, name)
        column += flat_ll(pending[name])
        setattr(buffer, name, column)


def beta_params(mean, var):
    """Return ``(a, b)`` computed from ``mean`` and ``var``.

    ``a = (-1/mean + (1 - mean)/var) * mean**2`` and ``b = a * (1/mean - 1)``,
    i.e. the method-of-moments parameters of a Beta distribution with the
    given mean and variance.
    """
    scale = (- 1 / mean) + (1 - mean) / var
    alpha = scale * mean**2
    beta = alpha * (1 / mean - 1)
    return alpha, beta

def argmax_onehot(vec):
    """Return an array shaped like ``vec`` with a 1 at each row's argmax.

    The argmax is taken along axis 1 (one winner per row); all other entries
    are zero. The result has the same dtype as ``vec``.
    """
    onehot = np.zeros_like(vec)
    row_ids = np.arange(vec.shape[0])
    winner_cols = vec.argmax(1)
    onehot[row_ids, winner_cols] = 1.0
    return onehot

