"""Default Worker class."""
from collections import defaultdict

import numpy as np

from garage import EpisodeBatch, StepType
from garage.experiment import deterministic
from garage.sampler import _apply_env_update
from garage.sampler.worker import Worker


class DefaultWorker(Worker):
    """Worker that steps one agent in one environment and records episodes.

    Args:
        seed (int): The seed to use to initialize random number generators.
        max_episode_length (int or float): The maximum length of episodes which
            will be sampled. Can be (floating point) infinity.
        worker_number (int): The number of the worker where this update is
            occurring. This argument is used to set a different seed for each
            worker.

    Attributes:
        agent (Policy or None): The worker's agent.
        env (Environment or None): The worker's environment.

    """

    def __init__(
            self,
            *,  # Require passing by keyword, since everything's an int.
            seed,
            max_episode_length,
            worker_number):
        super().__init__(seed=seed,
                         max_episode_length=max_episode_length,
                         worker_number=worker_number)
        self.agent = None
        self.env = None
        # Per-step buffers. They accumulate across episodes until
        # collect_episode() drains them.
        self._env_steps = []
        self._observations = []
        self._last_observations = []
        self._agent_infos = defaultdict(list)
        # Per-episode buffers, also drained by collect_episode().
        self._lengths = []
        self._episode_infos = defaultdict(list)
        # State of the episode currently in progress.
        self._prev_obs = None
        self._eps_length = 0
        self.worker_init()

    def worker_init(self):
        """Initialize a worker.

        Seeds the random number generators with `seed + worker_number`, so
        each worker gets a distinct seed. Does nothing when `seed` is None.

        """
        if self._seed is not None:
            deterministic.set_seed(self._seed + self._worker_number)

    def update_agent(self, agent_update):
        """Update an agent, assuming it implements :class:`~Policy`.

        Args:
            agent_update (np.ndarray or dict or Policy): If a tuple, dict, or
                np.ndarray, these should be parameters to agent, which should
                have been generated by calling `Policy.get_param_values`.
                Alternatively, a policy itself. Note that other implementations
                of `Worker` may take different types for this parameter.

        """
        if isinstance(agent_update, (dict, tuple, np.ndarray)):
            # Parameter update: requires that an agent was already installed.
            self.agent.set_param_values(agent_update)
        elif agent_update is not None:
            # Anything else that is not None replaces the agent wholesale.
            self.agent = agent_update

    def update_env(self, env_update):
        """Use any non-None env_update as a new environment.

        A simple env update function. If env_update is not None, it should be
        the complete new environment.

        This allows changing environments by passing the new environment as
        `env_update` into `obtain_samples`.

        Args:
            env_update(Environment or EnvUpdate or None): The environment to
                replace the existing env with. Note that other implementations
                of `Worker` may take different types for this parameter.

        Raises:
            TypeError: If env_update is not one of the documented types.

        """
        self.env, _ = _apply_env_update(self.env, env_update)

    def start_episode(self):
        """Begin a new episode.

        Resets the episode step counter, the environment and the agent. The
        initial observation becomes the input for the first step, and the
        per-episode info returned by `env.reset()` is buffered.

        """
        self._eps_length = 0
        self._prev_obs, episode_info = self.env.reset()
        for k, v in episode_info.items():
            self._episode_infos[k].append(v)

        self.agent.reset()

    def step_episode(self, deterministic=False):
        """Take a single time-step in the current episode.

        Args:
            deterministic (bool): Passed to the agent's `get_action` as its
                second positional argument. Presumably it requests a
                deterministic action from the policy; confirm against the
                agent implementation.

        Returns:
            bool: True iff the episode is done, either due to the environment
            indicating termination or due to reaching `max_episode_length`.

        """
        # NOTE: inside this method the `deterministic` parameter shadows the
        # module-level `garage.experiment.deterministic` import. The module
        # is not needed here.
        if self._eps_length < self._max_episode_length:
            a, agent_info = self.agent.get_action(self._prev_obs, deterministic)
            es = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._env_steps.append(es)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            self._eps_length += 1

            if not es.terminal:
                # Non-terminal steps (including a timeout) advance the
                # observation. If the length limit has now been reached, the
                # next call records this observation as the last one.
                self._prev_obs = es.observation
                return False
        self._lengths.append(self._eps_length)
        self._last_observations.append(self._prev_obs)
        return True

    @staticmethod
    def _as_array_dict(infos):
        """Convert a mapping of lists into a plain dict of numpy arrays.

        Args:
            infos (dict[str, list]): Per-key lists of values.

        Returns:
            dict[str, np.ndarray]: Same keys, each value passed through
                `np.asarray`.

        """
        return {k: np.asarray(v) for k, v in infos.items()}

    def collect_episode(self):
        """Collect the buffered episodes, clearing the internal buffers.

        Returns:
            EpisodeBatch: A batch of the episodes completed since the last call
                to collect_episode().

        """
        # Take ownership of every buffer and reset it for subsequent episodes.
        observations = self._observations
        self._observations = []
        last_observations = self._last_observations
        self._last_observations = []
        env_steps = self._env_steps
        self._env_steps = []
        agent_infos = self._agent_infos
        self._agent_infos = defaultdict(list)
        episode_infos = self._episode_infos
        self._episode_infos = defaultdict(list)
        lengths = self._lengths
        self._lengths = []

        actions = []
        rewards = []
        step_types = []
        env_infos = defaultdict(list)
        for es in env_steps:
            actions.append(es.action)
            rewards.append(es.reward)
            step_types.append(es.step_type)
            for k, v in es.env_info.items():
                env_infos[k].append(v)

        # All info mappings are handed over as plain dicts. Previously
        # episode_infos leaked out as a defaultdict.
        return EpisodeBatch(env_spec=self.env.spec,
                            episode_infos=self._as_array_dict(episode_infos),
                            observations=np.asarray(observations),
                            last_observations=np.asarray(last_observations),
                            actions=np.asarray(actions),
                            rewards=np.asarray(rewards),
                            step_types=np.asarray(step_types, dtype=StepType),
                            env_infos=self._as_array_dict(env_infos),
                            agent_infos=self._as_array_dict(agent_infos),
                            lengths=np.asarray(lengths, dtype='i'))

    def rollout(self):
        """Sample a single episode of the agent in the environment.

        Returns:
            EpisodeBatch: The collected episode.

        """
        self.start_episode()
        while not self.step_episode():
            pass
        return self.collect_episode()

    def shutdown(self):
        """Close the worker's environment, if one has been set."""
        # `env` starts as None; calling close() on it would raise.
        if self.env is not None:
            self.env.close()
