from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import gym
import cv2

from easydict import EasyDict as edict
import marlgrid.envs
import marlgrid

def make_environment(env_cfg, lock=None):
    """
    Build a MarlGrid environment from ``env_cfg`` and apply the standard wrappers.

    The raw grid world is wrapped first with
    ``DictObservationNormalizationWrapper`` and then with
    ``SimplifiedObservationWrapper``, so the returned object is the
    outermost (simplified) wrapper.

    Args:
        env_cfg: Config object with attribute access. ``env_cfg.env_name``
            must start with ``'MarlGrid'``; the remaining fields are consumed
            by ``create_grid_world_env``.
        lock: Accepted for call-site compatibility; not used by this function.

    Returns:
        The wrapped environment.

    Raises:
        AssertionError: If ``env_cfg.env_name`` does not start with
            ``'MarlGrid'``.
    """
    assert env_cfg.env_name.startswith('MarlGrid')

    # Wrappers are listed innermost-first.
    # GridWorldEvaluatorWrapper is intentionally not part of the stack.
    wrapped = create_grid_world_env(env_cfg)
    for wrapper_cls in (DictObservationNormalizationWrapper,
                        SimplifiedObservationWrapper):
        wrapped = wrapper_cls(wrapped)
    return wrapped

def create_grid_world_env(env_cfg):
    """
    Instantiate a MarlGrid environment described by ``env_cfg``.

    The environment name is derived with ``get_env_name`` and everything is
    forwarded to ``marlgrid.envs.register_env``. Three arguments are fixed
    here rather than read from the config: ``comm_dim=2``,
    ``n_adversaries=0`` and ``use_gym_env=False``.

    Args:
        env_cfg: Config object with attribute access providing every field
            read below (agent count, grid/view sizes, communication and
            observation flags, reward settings, seed, clutter density, ...).

    Returns:
        Whatever ``marlgrid.envs.register_env`` returns for these settings.
    """
    # Fields that are passed through, under the same name, in ``env_configs``.
    nested_fields = (
        'max_steps',
        'team_reward_multiplier',
        'team_reward_type',
        'team_reward_freq',
        'seed',
        'active_after_done',
        'discrete_position',
        'separate_rew_more',
        'info_gain_rew',
    )

    register_kwargs = {
        'env_name': get_env_name(env_cfg),
        'n_agents': env_cfg.num_agents,
        'grid_size': env_cfg.grid_size,
        'view_size': env_cfg.view_size,
        'view_tile_size': env_cfg.view_tile_size,
        'comm_dim': 2,
        'comm_len': env_cfg.comm_len,
        'discrete_comm': env_cfg.discrete_comm,
        'n_adversaries': 0,
        'observation_style': env_cfg.observation_style,
        'observe_position': env_cfg.observe_position,
        'observe_self_position': env_cfg.observe_self_position,
        'observe_done': env_cfg.observe_done,
        'observe_self_env_act': env_cfg.observe_self_env_act,
        'observe_t': env_cfg.observe_t,
        'neutral_shape': env_cfg.neutral_shape,
        'can_overlap': env_cfg.can_overlap,
        'use_gym_env': False,
        'env_configs': {
            field: getattr(env_cfg, field) for field in nested_fields
        },
        'clutter_density': env_cfg.clutter_density,
    }

    return marlgrid.envs.register_env(**register_kwargs)


def get_env_name(env_cfg):
    """
    Compose the MarlGrid environment id that corresponds to ``env_cfg``.

    The id has the form
    ``MarlGrid-{N}Agent[{L}C[cont]][{V}Vs]{G}x{G}-v0`` where:

    * ``{L}C`` is added only when ``comm_len > 0``, followed by ``cont``
      when communication is not discrete;
    * ``{V}Vs`` is added only when ``view_size`` differs from 7.

    Args:
        env_cfg: Config object exposing ``env_type``, ``num_agents``,
            ``comm_len``, ``discrete_comm``, ``view_size`` and ``grid_size``.

    Returns:
        str: The environment id.

    Raises:
        AssertionError: If ``env_cfg.env_type`` is not ``'c'``.
    """
    assert env_cfg.env_type == 'c'

    pieces = [f'MarlGrid-{env_cfg.num_agents}Agent']

    if env_cfg.comm_len > 0:
        pieces.append(f'{env_cfg.comm_len}C')
        if not env_cfg.discrete_comm:
            pieces.append('cont')

    if env_cfg.view_size != 7:
        pieces.append(f'{env_cfg.view_size}Vs')

    pieces.append(f'{env_cfg.grid_size}x{env_cfg.grid_size}-v0')
    return ''.join(pieces)


class DictObservationNormalizationWrapper(gym.Wrapper):
    """
    Rescale observations with ``x -> 2 * (x / 255 - 0.5)``.

    The mapping sends 0 to -1 and 255 to +1. It is applied to every entry of
    the observation dict except the ``'global'`` entry: for dict-valued
    entries only the ``'pov'`` field is rescaled, other entries are rescaled
    as a whole.

    Both ``reset`` and ``step`` apply the rescaling, so the first observation
    of an episode is on the same scale as all subsequent ones.
    """

    def __init__(self, env):
        super().__init__(env)

    @staticmethod
    def _normalize(obs_dict):
        """Rescale ``obs_dict`` in place and return it."""
        for k, v in obs_dict.items():
            if k == 'global':
                continue

            if isinstance(v, dict):
                # Only the image part is rescaled; other fields are untouched.
                v['pov'] = (2. * ((v['pov'] / 255.) - 0.5))
            else:
                # Replacing the value of an existing key is safe while
                # iterating: the dict's size does not change.
                obs_dict[k] = (2. * ((v / 255.) - 0.5))
        return obs_dict

    def reset(self, **kwargs):
        """
        Reset the wrapped env and return its rescaled observation dict.

        NOTE(review): assumes ``env.reset`` returns an observation dict with
        the same layout as the one returned by ``env.step`` -- confirm
        against the wrapped environment.
        """
        return self._normalize(self.env.reset(**kwargs))

    def step(self, action):
        """
        Step the wrapped env and rescale the returned observations.

        Returns:
            The ``(obs_dict, rew_dict, done_dict, info_dict)`` tuple of the
            wrapped env, with ``obs_dict`` rescaled in place.
        """
        obs_dict, rew_dict, done_dict, info_dict = self.env.step(action)
        return self._normalize(obs_dict), rew_dict, done_dict, info_dict


class SimplifiedObservationWrapper(gym.Wrapper):
    """
    Expose the per-agent dict interface of the wrapped env as positional sequences.

    Actions are taken as an indexable sequence (one entry per agent) and
    converted to an ``{'agent_<i>': action}`` dict. Observations come back as
    a list of ``(pov, selfpos)`` tuples, rewards as a numpy array and done
    flags as a list, all ordered by agent index. A small per-step penalty is
    added to the reward of every agent that is not done yet.
    """

    def __init__(self, env):
        super().__init__(env)
        n_agents = len(self.agents)

        # One (pov, selfpos) Tuple space per agent, built from the wrapped
        # env's spaces before they are overridden below.
        self.simplified_obs_space = [
            gym.spaces.Tuple((self.observation_space['pov'],
                              self.observation_space['selfpos']))
            for _ in range(n_agents)
        ]
        self.new_action_space = [self.action_space for _ in range(n_agents)]

        self.action_space = tuple(self.new_action_space)
        self.observation_space = self.simplified_obs_space
        self.step_penalty = -0.01

    def dict_to_tup(self, obs_dict):
        """Return ``[(pov, selfpos), ...]`` ordered by agent index."""
        per_agent = (obs_dict[f'agent_{a_i}']
                     for a_i in range(len(self.agents)))
        return [(agent_obs['pov'], agent_obs['selfpos'])
                for agent_obs in per_agent]

    def step(self, action):
        """
        Step the wrapped env with one action per agent.

        Args:
            action: Indexable sequence; ``action[i]`` is agent ``i``'s action.

        Returns:
            tuple: ``(obs_tups, reward_arr, done_arr, info_dict)`` where
            ``obs_tups`` is the output of ``dict_to_tup``, ``reward_arr`` is
            a float array of per-agent rewards, ``done_arr`` is the list of
            per-agent ``info_dict[agent]['done']`` values and ``info_dict``
            is passed through unchanged.
        """
        agent_keys = [f'agent_{a_i}' for a_i in range(len(self.agents))]
        action_dict = {key: action[a_i] for a_i, key in enumerate(agent_keys)}

        obs_dict, rew_dict, _, info_dict = self.env.step(action_dict)

        reward_arr = np.zeros(len(agent_keys))
        done_arr = []
        for a_i, key in enumerate(agent_keys):
            agent_done = info_dict[key]['done']
            # Explicit comparison with False is kept on purpose: it preserves
            # the exact behaviour for non-bool ``done`` values.
            if agent_done == False:  # noqa: E712
                reward_arr[a_i] = rew_dict[key] + self.step_penalty
            else:
                reward_arr[a_i] = rew_dict[key]
            done_arr.append(agent_done)

        return self.dict_to_tup(obs_dict), reward_arr, done_arr, info_dict

    def reset(self):
        """Reset the wrapped env and return its observations as tuples."""
        return self.dict_to_tup(self.env.reset())

class GridWorldEvaluatorWrapper(gym.Wrapper):
    """
    Record step results and render annotated frames for evaluation videos.

    ``step`` caches the wrapped env's return value; ``get_raw_obs`` renders
    the env to an RGB frame, rescales it by ``video_scale`` and, when
    ``show_reward`` is set, draws the cached rewards and communication info
    onto the frame.
    """

    def __init__(self, env):
        super().__init__(env)
        self.video_scale = 2
        self.show_reward = True
        # Result of the most recent ``step``. Stays None until the first
        # step so that ``get_raw_obs`` can be called right after ``reset``.
        self._out = None

    def get_raw_obs(self):
        """
        Render the current frame, with the reward overlay when available.

        Returns:
            The rendered uint8 frame, resized by ``video_scale``. If no
            ``step`` has been taken yet there is nothing to overlay and the
            plain frame is returned.
        """
        frame = self.env.render(mode='rgb_array', show_more=self.show_reward)
        frame = self.resize(frame.astype(np.uint8))

        if self._out is None:
            return frame

        _, rew, _, info = self._out
        return self.render_reward(frame, rew, info)

    def step(self, action):
        """Step the wrapped env, cache the result for rendering, and return it."""
        self._out = self.env.step(action)
        return self._out

    def resize(self, frame):
        """Return ``frame`` scaled by ``video_scale`` (unchanged when it is 1)."""
        if self.video_scale != 1:
            frame = cv2.resize(frame, None,
                               fx=self.video_scale,
                               fy=self.video_scale,
                               interpolation=cv2.INTER_AREA)
        return frame

    def render_reward(self, frame, reward_dict, info_dict):
        """
        Draw reward and communication text onto ``frame`` in place.

        Nothing is drawn when ``show_reward`` is false.

        Args:
            frame: Image array to draw on.
            reward_dict: Mapping of agent key to numeric reward.
            info_dict: Step info. ``info_dict['rew_by_act'][1]`` is read as a
                mapping of agent key to numeric reward, and every entry whose
                key ends in a digit is read as a per-agent dict with
                ``'comm'`` and ``'comm_str'`` fields.
                NOTE(review): ``'comm_str'`` is presumably a list of strings
                when non-empty -- confirm against the environment.

        Returns:
            The same ``frame`` object.
        """
        if not self.show_reward:
            return frame

        # Agent keys are abbreviated to first + last character
        # (e.g. 'agent_0' -> 'a0').
        to_render = ['env rew']
        to_render.extend(f'{k[0] + k[-1]}: {v:.3f}'
                         for k, v in reward_dict.items())
        to_render.append('comm rew')
        to_render.extend(f'{k[0] + k[-1]}: {v:.3f}'
                         for k, v in info_dict['rew_by_act'][1].items())

        # Two separate passes keep all raw comm lines ahead of the
        # per-agent comm_str sections.
        for k, v in info_dict.items():
            if k[-1].isdigit():
                to_render.append(str(k[0] + k[-1]) + ': ' + str(v['comm']))

        for k, v in info_dict.items():
            if k[-1].isdigit() and v['comm_str'] != '':
                to_render.append(str(k[0] + k[-1]) + ' ---')
                to_render.extend(v['comm_str'])

        str_spacing = 30
        # Text column: past the midpoint of the area right of a leading
        # square region of side frame.shape[0], plus a 10 px margin.
        x_start = ((frame.shape[1] - frame.shape[0]
                    ) // 2) + frame.shape[0] + 10
        y_start = int(0.1 * frame.shape[0])
        for i, text_to_render in enumerate(to_render):
            cv2.putText(frame, text_to_render,
                        (x_start, y_start + (i * str_spacing)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        return frame
