import sys
sys.path.append('/home/lc/MAgent/python')

from src.envs.multiagentenv import MultiAgentEnv
import numpy as np
import torch as th
import magent
from src.utils.dict2namedtuple import convert


def get_config_pursuit_attack(map_size, coop_reward, attack_penalty):
    """Build the MAgent gridworld config for the cooperative pursuit task.

    Two agent types are registered and added as groups: predators first,
    prey second.

    - Predators: speed 1, view radius 5, attack radius 1, and an
      ``attack_penalty`` attribute taken from the argument. An earlier note
      gave its default as -0.15.
    - Prey: speed 0, view radius 4, attack radius 0.

    A single reward rule is installed. When two predator symbols attack the
    same prey symbol, each of the two attackers receives ``coop_reward``.

    Args:
        map_size: side length of the square map (used for width and height).
        coop_reward: reward given to each of the two cooperating attackers.
        attack_penalty: value stored in the predator type's
            ``attack_penalty`` attribute.

    Returns:
        The populated ``magent.gridworld.Config``.
    """
    gridworld = magent.gridworld
    config = gridworld.Config()
    config.set({"map_width": map_size, "map_height": map_size})

    # 'speed' is the movement range; 'attack_range' is the attack radius.
    predator_attrs = {
        'width': 1, 'length': 1, 'hp': 1, 'speed': 1,
        'view_range': gridworld.CircleRange(5),
        'attack_range': gridworld.CircleRange(1),
        'attack_penalty': attack_penalty,
    }
    predator_type = config.register_agent_type("predator", predator_attrs)

    prey_attrs = {
        'width': 1, 'length': 1, 'hp': 1, 'speed': 0,
        'view_range': gridworld.CircleRange(4),
        'attack_range': gridworld.CircleRange(0),
    }
    prey_type = config.register_agent_type("prey", prey_attrs)

    predator_group = config.add_group(predator_type)
    prey_group = config.add_group(prey_type)

    first_hunter = gridworld.AgentSymbol(predator_group, index='any')
    second_hunter = gridworld.AgentSymbol(predator_group, index='any')
    target = gridworld.AgentSymbol(prey_group, index='any')

    # Reward only joint attacks: both hunters must hit the same target.
    joint_attack = (gridworld.Event(first_hunter, 'attack', target)
                    & gridworld.Event(second_hunter, 'attack', target))
    config.add_reward_rule(joint_attack,
                           receiver=[first_hunter, second_hunter],
                           value=[coop_reward] * 2)

    return config


class Lift(MultiAgentEnv):
    """Cooperative pursuit environment on top of a MAgent ``GridWorld``.

    Predators are the learning agents; the group is ``self.handles[0]`` and
    has ``n_agents`` members. Prey live in group ``self.handles[1]`` and start
    with ``n_agents + more_enemy`` members. Every step, each prey takes a
    random action drawn from its own action space.

    The world is configured by ``get_config_pursuit_attack``, so predators
    are rewarded when two of them attack the same prey in one step.
    """

    def __init__(self, **kwargs):
        # Unpack arguments from sacred. ``env_args`` may be a plain dict or an
        # already-converted namedtuple.
        args = kwargs["env_args"]
        if isinstance(args, dict):
            args = convert(args)
        self.args = args
        self.map_size = int(args.map_size)
        self.map_name = args.map_name
        self.coop_reward = args.coop_reward
        self.attack_penalty = args.attack_penalty
        self.n_agents = int(args.n_agents)
        # Multipliers used by reset() to add extra walls / extra prey (0 = off).
        self.more_walls = 0
        self.more_enemy = 0
        self.mini_map_shape = args.mini_map_shape
        self.run_time = '0'

        self.env = magent.GridWorld(get_config_pursuit_attack(self.map_size, self.coop_reward, self.attack_penalty))
        self.env.set_seed(args.seed)
        self.handles = self.env.get_handles()

        # Per-agent observation = flattened local view + feature vector.
        feature_dim = self.env.get_feature_space(self.handles[0])
        view_dim = self.env.get_view_space(self.handles[0])
        v_dim_total = view_dim[0] * view_dim[1] * view_dim[2]       # width*height*n_channels
        self.obs_shape = int(v_dim_total + feature_dim[0])
        # Global state is the flattened minimap. The factor 2 presumably is one
        # minimap channel per group (predators, prey); TODO confirm.
        self.state_shape = int((self.mini_map_shape * self.mini_map_shape) * 2)
        self.n_actions = self.env.action_space[0][0]        # predators; 9 per original note
        # Prey's action space
        self.fixed_n_actions = self.env.action_space[1][0]      # 1 (do nothing) per original note
        self.episode_limit = 100

        print("Predator's action space:", self.n_actions)
        print("Prey's action space:", self.fixed_n_actions)

        self.enemy_feats_dim = 0
        self.pos_dim = 2
        self.use_other_feature = True

        self.steps = 0

    def get_state(self):
        """Return the flattened global minimap.

        The minimap is ``mini_map_shape`` x ``mini_map_shape``; see
        ``state_shape`` for the expected flattened length.
        """
        return self.env.get_global_minimap(self.mini_map_shape, self.mini_map_shape).flatten()

    def get_obs(self):
        """Return a list of ``n_agents`` flat observation vectors, one per predator.

        Each vector is the predator's flattened view concatenated with its
        feature vector. Predators can disappear (``step`` calls
        ``clear_dead()``). If fewer than ``n_agents`` are currently present,
        the missing slots are filled with zero vectors of length
        ``obs_shape`` instead of raising ``IndexError``.
        """
        obs_all = self.env.get_observation(self.handles[0])
        view = obs_all[0]
        feature = obs_all[1]
        n_present = min(len(view), self.n_agents)
        obs = [np.concatenate([view[i].flatten(), feature[i]]) for i in range(n_present)]
        obs.extend(np.zeros(self.obs_shape, dtype=np.float32)
                   for _ in range(self.n_agents - n_present))
        return obs

    def reset(self):
        """Reset the world and repopulate it.

        Adds random walls (none while ``more_walls`` is 0), ``n_agents``
        predators and ``n_agents + more_enemy`` prey.

        Returns:
            ``(obs, state)`` as produced by ``get_obs`` and ``get_state``.
        """
        self.env.reset()
        self.steps = 0
        handles = self.env.get_handles()
        self.env.add_walls(method="random", n=self.n_agents * 2 * self.more_walls)
        self.env.add_agents(handles[0], method="random", n=self.n_agents)
        self.env.add_agents(handles[1], method="random", n=self.n_agents + self.more_enemy)
        return self.get_obs(), self.get_state()

    @staticmethod
    def _to_int32(act):
        """Convert one predator action (tensor, ndarray, numpy scalar or int) to int32."""
        if th.is_tensor(act):
            # detach() + cpu() so that CUDA tensors and tensors requiring grad convert.
            act = act.detach().cpu().numpy()
        return np.asarray(act).astype(np.int32)

    def step(self, actions):
        """Advance the world one step.

        Args:
            actions: one action per predator. Elements may be torch tensors
                (on any device), numpy arrays/scalars, or plain ints.

        Returns:
            ``(reward, terminated, info)``:

            - ``reward`` is the summed predator reward for this step.
            - ``terminated`` is the value returned by ``self.env.step()``,
              forced to ``True`` once ``episode_limit`` steps have elapsed.
            - ``info`` is an empty dict.

        NOTE(review): pymarl-style runners may read ``info["episode_limit"]``
        to tell time-outs from true terminations. This env never sets it;
        confirm that this is intended.
        """
        # Respawn missing predators so the learner's n_agents actions all have an agent.
        curr_num_agents = self.env.get_num(self.handles[0])
        if curr_num_agents < self.n_agents:
            self.env.add_agents(self.handles[0], method="random", n=self.n_agents - curr_num_agents)

        predator_actions = np.array([self._to_int32(act) for act in actions])
        n_prey = self.env.get_num(self.handles[1])
        prey_actions = np.array(np.random.randint(0, self.fixed_n_actions, size=n_prey, dtype='int32'))
        self.env.set_action(self.handles[0], predator_actions)
        self.env.set_action(self.handles[1], prey_actions)

        terminated = self.env.step()
        reward = np.sum(self.env.get_reward(self.handles[0]))
        self.env.clear_dead()

        self.steps += 1
        # '>=' rather than '==' so stepping past the limit still reports termination.
        if self.steps >= self.episode_limit:
            terminated = True

        info = {}
        return reward, terminated, info

    def get_total_actions(self):
        """Return the size of the predators' action space."""
        return self.n_actions

    def get_avail_actions(self):
        """Return an all-ones availability mask per predator (every action is always allowed)."""
        return [np.ones(self.n_actions) for _ in range(self.n_agents)]

    def get_obs_size(self):
        """Return the length of one predator observation vector."""
        return self.obs_shape

    def get_state_size(self):
        """Return the length of the flattened global state."""
        return self.state_shape

    def get_stats(self):
        """No statistics are tracked; returns None."""
        pass

    def get_env_info(self):
        """Return the base-class env info dict."""
        info = MultiAgentEnv.get_env_info(self)
        return info

    def close(self):
        pass

    def render(self):
        # TODO: rendering is not implemented.
        pass

    def seed(self):
        """Re-apply the seed from ``env_args`` to the underlying GridWorld."""
        self.env.set_seed(self.args.seed)