from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from gym.spaces import Discrete, Box, MultiDiscrete
from ray import rllib

from env.gym_fortattack.fortattack import make_fortattack_env

import numpy as np
import time
from skimage.transform import rescale, resize, downscale_local_mean
from PIL import Image, ImageFilter
from ray.rllib import MultiAgentEnv
from ray.rllib.env.apis.task_settable_env import TaskSettableEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.annotations import override
import random


# RLlib multi-agent wrapper around the FortAttack environment.
class RLlibMultiFortAttack(MultiAgentEnv, TaskSettableEnv):
    """RLlib multi-agent wrapper around the FortAttack environment.

    Agents are keyed by integer index ``0 .. num_agents - 1``. Indices below
    ``num_guards`` are guards; the remaining indices are the agents that can
    be driven by hard-coded waypoint paths (attackers). The wrapper also
    implements ``TaskSettableEnv``: the curriculum level (1-4) selects the
    spawn mode and whether hard-coded attacker paths are used, and takes
    effect on the next ``reset()``.
    """

    metadata = {
        "render.modes": ["rgb_array"],
    }

    # Curriculum level -> (env spawn_pos mode, use hard-coded attacker paths).
    # Levels not listed here leave the current settings unchanged on reset().
    _CURRICULUM_LEVELS = {
        1: ("random", False),
        2: ("opposite_ends", False),
        3: ("opposite_ends", True),
        4: ("random", True),
    }

    # Attacker strategy name -> index into ``hard_coded_paths``.
    _STRATAGY_PATH_INDEX = {"all left": 1, "all right": 0}

    def __init__(self, **fortattack_args):
        """Create the wrapped FortAttack environment.

        Arguments:
            **fortattack_args: Keyword arguments for ``make_fortattack_env``.
                ``name`` (optional) and ``current_concepts`` (required) are
                consumed here and not forwarded. ``return_image``,
                ``multi_discrete``, ``hard_coded_paths``,
                ``use_hard_coded_paths`` and ``default_spawn_pos`` are read
                by this wrapper and are also forwarded.
        """
        self.return_image = fortattack_args["return_image"]
        # Keys consumed by the wrapper; they are not passed to make_fortattack_env.
        fortattack_args.pop("name", None)
        self.current_concepts = fortattack_args.pop("current_concepts")
        self._env = make_fortattack_env(**fortattack_args)

        self.num_agents = self._env.n
        self.num_attackers = self._env.num_attackers
        self.num_guards = self._env.num_guards
        self.already_done = {i: False for i in range(self.num_agents)}
        self._agent_ids = list(range(self.num_agents))

        self.observation_space_dict = self._make_dict(self._env.observation_space)
        self.action_space_dict = self._make_dict(self._env.action_space)

        # Number of components in one agent action (1 when not multi-discrete).
        if fortattack_args["multi_discrete"]:
            self.action_len = self._env.action_space[0].shape[0]
        else:
            self.action_len = 1

        self.hard_coded_paths = fortattack_args["hard_coded_paths"]
        self.use_hard_coded_paths = fortattack_args["use_hard_coded_paths"]
        self.hard_coded_per_agent = {i: None for i in range(self.num_agents)}
        self.hard_coded_paths_current = [0 for _ in range(self.num_agents)]
        self.stratagy = np.random.choice(["all left", "all right"])
        self.gen_per_agent_path()

        # NOTE(review): not read anywhere in this class; reset() sets
        # self._env.spawn_pos from the curriculum level instead.
        self.spawn_pos = fortattack_args["default_spawn_pos"]
        self.cur_level = 3
        # When True, step() keeps obs/rew/info for agents that were already done.
        self.return_all = False

    @override(TaskSettableEnv)
    def sample_tasks(self, n_tasks):
        """Sample ``n_tasks`` random curriculum levels, each in 1..4."""
        return [random.randint(1, 4) for _ in range(n_tasks)]

    @override(TaskSettableEnv)
    def get_task(self):
        """Return the current curriculum level."""
        return self.cur_level

    @override(TaskSettableEnv)
    def set_task(self, task):
        """Set the curriculum level; it is applied on the next ``reset()``."""
        self.cur_level = task

    def _singleobstomulti(self, obs_list):
        """Build per-agent stacked observations.

        Row ``i`` of the returned array holds agent ``i``'s own observation
        first, followed by every other agent's observation in index order.
        """
        return np.array(
            [
                [obs_list[i]] + [obs_list[j] for j in range(self.num_agents) if j != i]
                for i in range(self.num_agents)
            ]
        )

    def gen_per_agent_path(self):
        """Give every attacker a waypoint path for the current ``self.stratagy``.

        All attackers share the base path selected by the strategy. Each
        attacker's copy has its waypoint x coordinates scaled by its own
        factor drawn from ``np.random.uniform(0, 1)``. The attacker's path
        cursor is then moved forward past waypoints whose y coordinate is
        below the attacker's current y position; the last waypoint is never
        skipped. Cursors are not reset here (callers reset
        ``hard_coded_paths_current`` beforehand).

        Raises:
            ValueError: If ``self.stratagy`` is not a known strategy.
        """
        try:
            path_index = self._STRATAGY_PATH_INDEX[self.stratagy]
        except KeyError:
            raise ValueError("stratagy not recognized") from None

        for i in range(self.num_guards, self.num_agents):
            k = np.random.uniform(0, 1)
            path = [[k * a, b] for a, b in self.hard_coded_paths[path_index]]
            self.hard_coded_per_agent[i] = path

            agent_y = self._env.world.agents[i].state.p_pos[1]
            # Bounds check first so an empty path cannot be indexed.
            while (
                self.hard_coded_paths_current[i] < len(path) - 1
                and path[self.hard_coded_paths_current[i]][1] < agent_y
            ):
                self.hard_coded_paths_current[i] += 1

    def _renderObservations(self):
        """Render image observations for the agents.

        Each rendered RGB image is transposed from (H, W, C) to (C, H, W),
        divided by 255 and cast to float32. Presumably one image per agent,
        since the result is keyed by agent index via ``_make_dict``.
        """
        images = self._env.render(mode="rgb", render_multiple=True)
        return [
            (
                np.array(Image.fromarray(image, "RGB")).transpose((2, 0, 1)) / 255.0
            ).astype(np.float32)
            for image in images
        ]

    def reset(self):
        """Apply the curriculum level, reset the env and return observations.

        Also re-draws the attacker strategy and regenerates the hard-coded
        attacker paths. This happens after the underlying env reset, because
        the path cursors depend on the new agent positions.

        Returns:
            obs_dict: New observation for each agent, keyed by agent index.
        """
        level_settings = self._CURRICULUM_LEVELS.get(self.cur_level)
        if level_settings is not None:
            spawn_pos, use_paths = level_settings
            self._env.spawn_pos = spawn_pos
            self._env.using_hard_coded_paths = use_paths
            self.use_hard_coded_paths = use_paths

        self.already_done = {i: False for i in range(self.num_agents)}
        if self.return_image:
            self._env.reset()
            obs_dict = self._make_dict(self._renderObservations())
        else:
            obs_dict = self._make_dict(self._singleobstomulti(self._env.reset()))

        self.hard_coded_paths_current = [0 for _ in range(self.num_agents)]
        self.stratagy = np.random.choice(["all left", "all right"])
        self.gen_per_agent_path()

        return obs_dict

    def _current_stratagy(self):
        """Return the attacker-strategy code used in concepts and info.

        The code is 0 when hard-coded paths are disabled, 1 for
        "all left" and 2 for "all right".

        Raises:
            ValueError: If hard-coded paths are enabled and ``self.stratagy``
                is not a known strategy.
        """
        if not self.use_hard_coded_paths:
            return 0
        if self.stratagy == "all left":
            return 1
        if self.stratagy == "all right":
            return 2
        raise ValueError("stratagy not recognized")

    def _flatten_concepts(self, concepts, current_stratagy):
        """Flatten each agent's concept dict into a single flat list.

        Each guard's concept dict first gets an ``"attacker_stratagy"`` entry;
        this modifies ``concepts`` in place. The keys that are present are
        then concatenated in this fixed order:

        - can_shoot_ordinal (stacked, then flattened)
        - agent_targeting_ordinal
        - attacker_stratagy (one-hot of length 3)
        - relative_orientation
        - distance_between
        """
        for i in range(self.num_guards):
            concepts[i]["attacker_stratagy"] = current_stratagy

        final_concepts = []
        for i in range(self.num_agents):
            agent_concepts = concepts[i]
            flat = []
            if "can_shoot_ordinal" in agent_concepts:
                flat.extend(np.stack(agent_concepts["can_shoot_ordinal"]).reshape(-1))
            if "agent_targeting_ordinal" in agent_concepts:
                flat.extend(np.array(agent_concepts["agent_targeting_ordinal"]))
            if "attacker_stratagy" in agent_concepts:
                one_hot = [0, 0, 0]
                one_hot[agent_concepts["attacker_stratagy"]] = 1
                flat.extend(np.array(one_hot))
            if "relative_orientation" in agent_concepts:
                flat.extend(np.array(agent_concepts["relative_orientation"]))
            if "distance_between" in agent_concepts:
                flat.extend(np.array(agent_concepts["distance_between"]))
            final_concepts.append(flat)
        return final_concepts

    def get_concepts(self):
        """Return the flattened ground-truth concept list for every agent."""
        concepts = self._env.get_concepts()
        return self._flatten_concepts(concepts, self._current_stratagy())

    def step(self, action_dict):
        """Advance the environment by one step.

        Arguments:
            action_dict (dict): Maps integer agent index -> action. Agents
                missing from the dict get a zero action. When hard-coded
                paths are enabled, any agent that has a hard-coded path
                (attackers) has its action replaced by ``hard_coded_step_i``.

        Returns:
            obs_dict: New observations for each reported agent.
            rew_dict: Reward values for each reported agent.
            done_dict: Per-agent done flags. Each flag is reported only on
                the step the agent becomes done. The required ``"__all__"``
                key is True once every agent is done.
            info_dict: Per-agent info (raw observations, game result,
                death flags, actions, flattened ground-truth concepts, ...).

            Unless ``self.return_all`` is set, agents that were already done
            before this step are omitted from obs/rew/info.
        """
        actions = []
        for i in range(self.num_agents):
            if i in action_dict:
                actions.append(action_dict[i])
            elif i < self.num_guards or not self.use_hard_coded_paths:
                actions.append([0] * self.action_len)
            else:
                # Placeholder; agents with a hard-coded path are overwritten below.
                actions.append(0)

        if self.use_hard_coded_paths:
            for i in range(self.num_agents):
                if self.hard_coded_per_agent[i] is not None:
                    actions[i] = self.hard_coded_step_i(i)

        (
            raw_obs_list,
            rew_list,
            done_list,
            info_list,
            ground_truth_n,
            concepts,
        ) = self._env.step(actions)

        obs_raw = self._singleobstomulti(raw_obs_list)
        record_just_died = [1 if agent.justDied else 0 for agent in self._env.world.agents]
        result = self._env.world.gameResult

        if self.return_image:
            obs_dict = self._make_dict(self._renderObservations())
            raw_obs_dict = self._make_dict(obs_raw)
        else:
            obs_dict = self._make_dict(obs_raw)
            raw_obs_dict = obs_dict

        rew_dict = self._make_dict(rew_list)
        info_dict_ = self._make_dict(info_list)

        # Snapshot taken before this step's updates; agents already done are
        # dropped from the returned dicts below.
        done_before_step = self.already_done.copy()
        # RLlib wants each agent's done flag reported only once.
        done_dict = {}
        for i in range(self.num_agents):
            if done_list[i] and not self.already_done[i]:
                done_dict[i] = True
                self.already_done[i] = True
        done_dict["__all__"] = bool(np.all(done_list))

        current_stratagy = self._current_stratagy()
        final_concepts = self._flatten_concepts(concepts, current_stratagy)

        # FIXME: per-agent termination info (game result, deaths) is passed
        # through info_dict so RLlib itself does not need to be modified.
        info_dict = self._make_dict(
            [
                {
                    "done": done_list,
                    "result": result,
                    "just_died": record_just_died[i],
                    "raw_obs": raw_obs_dict[i],
                    "ground_truth": ground_truth_n,
                    "info_list": info_dict_,
                    "stratagy": current_stratagy,
                    "raw_actions": actions[i],
                    "ground truth concepts": final_concepts[i],
                }
                for i in range(self.num_agents)
            ]
        )

        if not self.return_all:
            for i in range(self.num_agents):
                if done_before_step[i]:
                    del obs_dict[i]
                    del rew_dict[i]
                    del info_dict[i]

        return obs_dict, rew_dict, done_dict, info_dict

    def render(self, mode="human"):
        """Delegate rendering to the underlying environment."""
        return self._env.render(mode=mode)

    def _make_dict(self, values):
        """Zip agent indices with ``values``; extra values are dropped by zip."""
        return dict(zip(self._agent_ids, values))

    def hard_coded_step_i(self, i):
        """Return the hard-coded discrete action for agent ``i``.

        The agent heads for the waypoint under its path cursor. Once it is
        within 0.1 (Euclidean distance) of that waypoint, the cursor advances,
        stopping at the final waypoint. The action moves along whichever axis
        has the larger remaining distance:

        - 1 when the goal lies in +x, 2 when it lies in -x
        - 3 when the goal lies in +y, 4 when it lies in -y
        """
        path = self.hard_coded_per_agent[i]
        pos = self._env.world.agents[i].state.p_pos
        cursor = self.hard_coded_paths_current[i]

        if cursor < len(path) - 1:
            if np.linalg.norm(np.asarray(path[cursor]) - pos) < 0.1:
                cursor += 1
                self.hard_coded_paths_current[i] = cursor

        xgoal, ygoal = path[cursor]
        xpos, ypos = pos
        xdiff = xgoal - xpos
        ydiff = ygoal - ypos

        if abs(xdiff) > abs(ydiff):
            return 1 if xdiff > 0 else 2
        return 3 if ydiff > 0 else 4
