import functools
import numpy as np
import torch
from garage.experiment import deterministic

from garage.sampler import DefaultWorker

from iod.utils import get_np_concat_obs


class RandomAgent:
    """Policy stand-in that ignores observations and samples the action space.

    Exposes the same ``get_action(obs) -> (action, agent_info)`` shape used by
    the worker's other policies, so it can be swapped in for exploration.
    """

    def __init__(self, action_space):
        # Any object with a ``sample()`` method can serve as the action space.
        self.action_space = action_space

    def get_action(self, obs):
        """Return ``(action, {})`` with the action drawn from ``action_space``.

        ``obs`` is accepted for interface compatibility and is not used.
        """
        sampled_action = self.action_space.sample()
        info = {}
        return sampled_action, info


class OptionWorker(DefaultWorker):
    """Garage worker that rolls out option-conditioned policies with exploration.

    Per-rollout settings come from ``self._cur_extras``, a sequence of
    dict-like entries indexed by ``self._cur_extra_idx`` (advanced once per
    :meth:`rollout`). Keys read here include ``"option"`` (required),
    ``"goal"``, ``"exploration_type"``, ``"exp_option"`` and
    ``"use_start_default_policy"``. A value stored as a list is indexed by the
    step number within the rollout.

    ``self.agent`` is expected to expose ``default_policy``,
    ``exploration_policy``, ``hierarchical_policy`` (may be None) and
    ``encoder`` (may be None), plus ``start_default_policy`` when
    ``use_start_default_policy`` is used.
    """

    def __init__(
        self,
        *,  # Require passing by keyword, since everything's an int.
        seed,
        max_path_length,
        worker_number,
        sampler_key,
        return_sim_states=0,
    ):
        """Initialize the worker.

        Args:
            seed (int): Base random seed (combined with ``worker_number`` in
                :meth:`worker_init`).
            max_path_length (int): Maximum rollout length, unless replaced by
                ``_max_path_length_override``.
            worker_number (int): Index of this worker.
            sampler_key: Stored as ``self._sampler_key``; not read elsewhere
                in this class.
            return_sim_states (int): If positive, the ``env.sim`` state before
                each step is recorded in ``env_infos["sim_states"]``.
        """
        super().__init__(
            seed=seed, max_path_length=max_path_length, worker_number=worker_number
        )
        self._sampler_key = sampler_key
        # When not None, used instead of max_path_length in step_rollout.
        self._max_path_length_override = None
        # Per-rollout extras, the index of the active entry, and the key set
        # taken from the first entry (see update_worker).
        self._cur_extras = None
        self._cur_extra_idx = None
        self._cur_extra_keys = set()
        self._render = False
        # When not None, copied onto the acting policy's
        # ``_force_use_mode_actions`` every step.
        self._deterministic_policy = None
        self._prev_hidden_states = None

        self.return_sim_states = bool(return_sim_states > 0)
        self.random_agent = None
        # Exploration counter carried between steps; its meaning depends on
        # the exploration type (flag, or remaining steps for types 7/8/11/12).
        self.cur_exploration = 0
        self.goal_reached = 0
        # Step at which exploration starts for exploration types 13-15.
        self._starting_time = None
        self.hierarchical_cooltime = 0
        self._prev_low_option = None

        self._prev_exp_option = None

    def update_agent(self, agent_update):
        """Update an agent, assuming it implements garage.Policy.

        Args:
            agent_update (np.ndarray or dict or tuple or garage.Policy): If a
                tuple, dict, or np.ndarray, these should be parameters to
                agent, which should have been generated by calling
                `policy.get_param_values`. Alternatively, a policy itself. Note
                that other implementations of `Worker` may take different types
                for this parameter.

        """
        if isinstance(agent_update, (dict, tuple, np.ndarray)):
            self.agent.set_param_values(agent_update)
        elif agent_update is not None:
            self.agent = agent_update

    def update_env(self, env_update):
        """Apply an environment update, then rebuild the random agent.

        Args:
            env_update (dict or object or None): A dict is applied as attribute
                assignments on the current env. Any other non-None value is
                passed to ``DefaultWorker.update_env``.
        """
        if env_update is not None:
            if isinstance(env_update, dict):
                for k, v in env_update.items():
                    setattr(self.env, k, v)
            else:
                super().update_env(env_update)
        # Rebuilt every time so it samples from the current env's action space.
        self.random_agent = RandomAgent(self.env.action_space)

    def worker_init(self):
        """Initialize a worker (seed RNGs with a per-worker offset)."""
        if self._seed is not None:
            deterministic.set_seed(self._seed + self._worker_number * 10000)

    def update_worker(self, worker_update):
        """Set worker attributes from a dict of ``{attribute_name: value}``.

        Setting ``_cur_extras`` also recomputes ``_cur_extra_keys`` from the
        keys of its first entry (an empty set for None or an empty sequence).

        Raises:
            TypeError: If ``worker_update`` is neither None nor a dict.
        """
        if worker_update is None:
            return
        if not isinstance(worker_update, dict):
            raise TypeError("Unknown worker update type.")
        for k, v in worker_update.items():
            setattr(self, k, v)
            if k == "_cur_extras":
                # An empty set (not None) keeps `in` checks and the per-step
                # loop over keys in start_rollout/step_rollout valid.
                if v is None or len(v) == 0:
                    self._cur_extra_keys = set()
                else:
                    self._cur_extra_keys = set(v[0].keys())

    def get_attrs(self, keys):
        """Return ``{key: value}`` for dotted attribute paths on this worker.

        Example key: ``"env.action_space"`` resolves ``self.env.action_space``.
        """
        attr_dict = {}
        for key in keys:
            attr_dict[key] = functools.reduce(getattr, [self] + key.split("."))
        return attr_dict

    def start_rollout(self):
        """Begin a new rollout.

        Resets the environment (passing the active extra's ``"goal"`` to
        ``env.reset`` when that key is present), clears per-rollout state,
        sets ``_starting_time`` when the active extra has
        ``"exploration_type"``, and resets the agent (every value of it, if
        the agent is a dict).
        """
        if "goal" in self._cur_extra_keys:
            goal = self._cur_extras[self._cur_extra_idx]["goal"]
            reset_kwargs = {"goal": goal}
        else:
            reset_kwargs = {}

        self._path_length = 0
        self._prev_obs = self.env.reset(**reset_kwargs)
        self._prev_extra = None
        self.cur_exploration = 0
        self.goal_reached = 0

        self.hierarchical_cooltime = 0
        self._prev_low_option = None
        self._prev_exp_option = None
        # NOTE(review): _prev_hidden_states is not cleared here, so recurrent
        # state carries across rollouts unless reset elsewhere — confirm intent.

        cur_extra = self._cur_extras[self._cur_extra_idx]
        if "exploration_type" in cur_extra:
            exploration_type = cur_extra["exploration_type"]
            if exploration_type in [13, 14]:
                # Start step drawn uniformly from [0, max_path_length).
                # NOTE(review): ignores _max_path_length_override — confirm.
                self._starting_time = np.random.randint(0, self._max_path_length)
            elif exploration_type == 15:
                self._starting_time = 0
            else:
                self._starting_time = None

        if isinstance(self.agent, dict):
            for policy in self.agent.values():
                policy.reset()
        else:
            self.agent.reset()

    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Chooses the acting policy (default, start-default, exploration or
        random) from ``goal_reached`` and the active extra's
        ``exploration_type``, builds its input, steps the environment and
        records the transition.

        Returns:
            bool: True iff the path is done, either due to the environment
            indicating termination or due to reaching the effective maximum
            path length.

        Raises:
            RuntimeError: If the active extras have no ``"option"`` entry.
            ValueError: If ``exploration_type`` is unknown.
            NotImplementedError: For exploration type 9.

        """
        cur_max_path_length = (
            self._max_path_length
            if self._max_path_length_override is None
            else self._max_path_length_override
        )

        if self._path_length < cur_max_path_length:  # one step per call, not a loop
            extras = self._cur_extras[self._cur_extra_idx]
            prev_exploration = self.cur_exploration
            # Low-level option recorded for this step when a hierarchical
            # policy exists. Exploring steps do not pick a new one, so they keep
            # the last option (previously this was left unbound -> crash).
            low_option = self._prev_low_option

            # ---- 1. Choose the acting policy ---------------------------------
            cur_exploration = 0
            agent = self.agent.default_policy

            if "use_start_default_policy" in extras:
                use_start_default_policy = extras["use_start_default_policy"][
                    self._path_length
                ]
                assert isinstance(use_start_default_policy, bool)
            else:
                use_start_default_policy = False
            if use_start_default_policy:
                agent = self.agent.start_default_policy

            exploration_type = extras.get("exploration_type", None)

            if self.goal_reached:
                # Once the goal is reached, only the exploration policy acts.
                cur_exploration = 1
                agent = self.agent.exploration_policy
            elif exploration_type is not None:
                # Exploration types ("eps" = per-step trigger probability):
                #   0: never explore; always the default policy
                #   1: random actions for the last 10 steps
                #   2: exploration policy for the last 10 steps
                #   3: eps-random (eps = 0.05)
                #   4: eps-exploration (eps = 0.05)
                #   5: eps-random (eps = 1/190); once triggered, until the end
                #   6: eps-exploration (eps = 1/190); once triggered, until the end
                #   7: eps-random (eps = 0.05/5); each trigger lasts 5 steps
                #   8: eps-exploration (eps = 0.05/5); each trigger lasts 5 steps
                #   9: not supported here (implemented as an argument instead)
                #  10: always the exploration policy (debugging)
                #  11: eps-exploration (eps = 0.05); each trigger lasts 20 steps
                #  12: eps-random (eps = 0.05); each trigger lasts 20 steps
                #  13: exploration policy from _starting_time until the end
                #  14: random actions from _starting_time until the end
                #  15: exploration policy for the whole rollout (start time 0)
                if exploration_type == 0:
                    cur_exploration = 0
                    agent = self.agent.default_policy

                elif exploration_type == 1:
                    if self._path_length >= cur_max_path_length - 10:
                        cur_exploration = 1
                        agent = self.random_agent
                elif exploration_type == 2:
                    if self._path_length >= cur_max_path_length - 10:
                        cur_exploration = 1
                        agent = self.agent.exploration_policy
                elif exploration_type == 3:
                    if np.random.random() < 0.05:
                        cur_exploration = 1
                        agent = self.random_agent
                elif exploration_type == 4:
                    if np.random.random() < 0.05:
                        cur_exploration = 1
                        agent = self.agent.exploration_policy
                elif exploration_type == 5:
                    if prev_exploration or np.random.random() < 1 / 190:
                        cur_exploration = 1
                        agent = self.random_agent
                elif exploration_type == 6:
                    if prev_exploration or np.random.random() < 1 / 190:
                        cur_exploration = 1
                        agent = self.agent.exploration_policy
                elif exploration_type == 7:
                    # cur_exploration counts down the remaining random steps.
                    if prev_exploration > 0:
                        cur_exploration = prev_exploration - 1
                    if cur_exploration > 0:
                        agent = self.random_agent
                    elif np.random.random() < 0.05 / 5:
                        cur_exploration = 5
                        agent = self.random_agent
                elif exploration_type == 8:
                    if prev_exploration > 0:
                        cur_exploration = prev_exploration - 1
                    if cur_exploration > 0:
                        agent = self.agent.exploration_policy
                    elif np.random.random() < 0.05 / 5:
                        cur_exploration = 5
                        agent = self.agent.exploration_policy
                elif exploration_type == 9:
                    raise NotImplementedError  # not used, implemented as argument

                elif exploration_type == 10:
                    cur_exploration = 1
                    agent = self.agent.exploration_policy

                elif exploration_type == 11:
                    if prev_exploration > 0:
                        cur_exploration = prev_exploration - 1
                    if cur_exploration > 0:
                        agent = self.agent.exploration_policy
                    elif np.random.random() < 0.05:
                        cur_exploration = 20
                        agent = self.agent.exploration_policy
                elif exploration_type == 12:
                    if prev_exploration > 0:
                        cur_exploration = prev_exploration - 1
                    if cur_exploration > 0:
                        agent = self.random_agent
                    elif np.random.random() < 0.05:
                        cur_exploration = 20
                        agent = self.random_agent

                elif exploration_type == 13:
                    assert self._starting_time is not None
                    if self._path_length >= self._starting_time:
                        cur_exploration = 1
                        agent = self.agent.exploration_policy

                elif exploration_type == 14:
                    assert self._starting_time is not None
                    if self._path_length >= self._starting_time:
                        cur_exploration = 1
                        agent = self.random_agent

                elif exploration_type == 15:
                    assert self._starting_time == 0, self._starting_time

                    if self._path_length >= self._starting_time:
                        cur_exploration = 1
                        agent = self.agent.exploration_policy
                    assert (
                        agent is self.agent.exploration_policy
                    ), self._path_length  # always use exploration policy

                else:
                    raise ValueError("Unknown exploration type", exploration_type)
                assert agent is not None

            # ---- 2. Build the policy input -----------------------------------
            if "option" not in self._cur_extra_keys:
                # Every policy input is built from the option, so it is required.
                raise RuntimeError(
                    "OptionWorker requires an 'option' entry in _cur_extras."
                )
            cur_extra_key = "option"

            if isinstance(extras[cur_extra_key], list):
                cur_extra = extras[cur_extra_key][self._path_length]
                if cur_extra is None:
                    # None entries are filled with the previous step's value
                    # and written back, so the recorded extras are complete.
                    cur_extra = self._prev_extra
                    extras[cur_extra_key][self._path_length] = cur_extra
            else:
                cur_extra = extras[cur_extra_key]
            if hasattr(self.env, "desired_goal"):
                # Envs exposing desired_goal override the option from extras.
                cur_extra = self.env.desired_goal

            exp_option = extras.get("exp_option", None)
            if cur_exploration > 0:
                assert agent is not self.agent.default_policy, agent
                if exp_option is not None:
                    # Exploration policy conditioned on its own option.
                    agent_input = get_np_concat_obs(
                        self._prev_obs,
                        exp_option,
                    )
                else:
                    agent_input = self._prev_obs

            elif self.agent.hierarchical_policy is not None:
                # The high-level policy maps (obs, option) to a low-level
                # option, which then conditions the acting policy.
                hierarchical_agent_input = get_np_concat_obs(
                    self._prev_obs,
                    cur_extra,
                )
                if self.hierarchical_cooltime > 0:
                    self.hierarchical_cooltime -= 1
                    low_option = self._prev_low_option
                else:
                    self.hierarchical_cooltime = 0
                    low_option, _ = self.agent.hierarchical_policy.get_action(
                        hierarchical_agent_input
                    )
                agent_input = get_np_concat_obs(
                    self._prev_obs,
                    low_option,
                )
                self._prev_low_option = low_option
            else:
                # The start-default policy is also a valid non-exploring actor
                # (selected above via use_start_default_policy).
                assert agent is self.agent.default_policy or (
                    use_start_default_policy
                    and agent is self.agent.start_default_policy
                ), agent
                agent_input = get_np_concat_obs(
                    self._prev_obs,
                    cur_extra,
                )

            self._prev_extra = cur_extra

            # ---- 3. Act --------------------------------------------------------
            if self._deterministic_policy is not None:
                agent._force_use_mode_actions = self._deterministic_policy
            # The exploration policy must never act deterministically.
            if agent is self.agent.exploration_policy:
                assert not agent._force_use_mode_actions, agent._force_use_mode_actions

            if self._prev_hidden_states is not None:
                a, agent_info = agent.get_action(
                    agent_input, hidden_states=self._prev_hidden_states
                )
            else:
                a, agent_info = agent.get_action(agent_input)

            if "hidden_states" in agent_info:
                self._prev_hidden_states = agent_info.pop("hidden_states")

            if self.return_sim_states:
                # Captured before stepping, i.e. the state the action acts on.
                prev_sim_state = self.env.sim.get_state().flatten()

            if self._render:
                next_o, r, d, env_info = self.env.step(a, render=self._render)
            else:
                next_o, r, d, env_info = self.env.step(a)

            if self.return_sim_states:
                env_info["sim_states"] = prev_sim_state

            # ---- 4. Record the transition --------------------------------------
            self._observations.append(self._prev_obs)
            self._rewards.append(r)
            self._actions.append(a)

            # The policy's own agent_info is intentionally discarded; only the
            # fields below are recorded.
            agent_info = {}

            if self.agent.hierarchical_policy is not None:
                agent_info["low_options"] = low_option

            agent_info["cur_exploration"] = cur_exploration > 0
            self.cur_exploration = cur_exploration

            for k, v in agent_info.items():
                self._agent_infos[k].append(v)

            # Record this step's value of every extra key.
            for k in self._cur_extra_keys:
                if isinstance(extras[k], list):
                    self._agent_infos[k].append(extras[k][self._path_length])
                else:
                    self._agent_infos[k].append(extras[k])

            # ---- 5. Goal check (every 10 steps, in encoder latent space) -----
            if self.agent.encoder is not None and not self._deterministic_policy:
                if not self.goal_reached and self._path_length % 10 == 0:
                    encoder = self.agent.encoder
                    last_obs = np.stack(self._observations[-1:], axis=0)
                    goal_obs = cur_extra[None]
                    assert last_obs.ndim == goal_obs.ndim == 2
                    enc_inp = torch.from_numpy(
                        np.concatenate([last_obs, goal_obs], axis=0)
                    ).to(next(encoder.parameters()).device)
                    # Inference only: no autograd graph is needed.
                    with torch.no_grad():
                        res = encoder(enc_inp).mean
                    assert res.ndim == 2
                    last_latent = res[-2:-1]
                    goal_latent = res[-1:]
                    dist_to_goal = torch.norm(last_latent - goal_latent, p=2, dim=1)
                    # Fixed latent-distance threshold for "goal reached".
                    succeeded = (dist_to_goal < 3).any()
                    assert succeeded.numel() == 1
                    if succeeded:
                        # From now on only the exploration policy acts.
                        self.goal_reached = 1

            for k, v in env_info.items():
                self._env_infos[k].append(v)
            self._path_length += 1
            self._terminals.append(d)
            if not d:
                self._prev_obs = next_o
                return False
        self._terminals[-1] = True
        self._lengths.append(self._path_length)
        self._last_observations.append(self._prev_obs)
        return True

    def rollout(self):
        """Sample a single rollout of the agent in the environment.

        When extras are set, ``_cur_extra_idx`` is advanced first, so each call
        uses the next entry of ``_cur_extras``.

        Returns:
            garage.TrajectoryBatch: The collected trajectory.

        """
        if self._cur_extras is not None:
            self._cur_extra_idx += 1
        self.start_rollout()
        while not self.step_rollout():
            pass
        batch = self.collect_rollout()
        # step_rollout marks only the final step of a rollout as terminal.
        assert np.sum(batch.terminals) <= 1
        return batch
