from abc import ABC, abstractmethod
from typing import List, Dict

from collections import deque
import time

import numpy as np
import torch as th

from .util import action_from_policy, clip_actions, resample_noise, action_from_rnn_policy
from .trajsaver import TransitionsMinimal
from .observation import Observation

from stable_baselines3.common.utils import (
    configure_logger,
    should_collect_more_steps
)
from stable_baselines3.common.policies import ActorCriticPolicy
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.utils import safe_mean

from sb3_contrib.common.recurrent.type_aliases import RNNStates

class Agent(ABC):
    """
    Abstract interface shared by every agent acting in a multi-agent
    environment.

    Concrete agents implement :meth:`get_action` to choose actions and
    :meth:`update` to receive reward / episode-termination feedback.
    """

    @abstractmethod
    def get_action(self, obs: Observation, record: bool = True) -> np.ndarray:
        """
        Choose the action to take in response to ``obs``.

        :param obs: Observation the agent must respond to
        :param record: When True, the (observation, action) pair is meant to
            be kept so that a learning agent can train on it
        :returns: The chosen action
        """

    @abstractmethod
    def update(self, reward: float, done: bool) -> None:
        """
        Feed back the reward and done flag for the agent's latest action,
        for agents that learn.

        An update applies to the most recent ``get_action`` call that was
        made with ``record=True``. If several updates arrive for that same
        call, their rewards add up and the done flag of the final update is
        the one that counts.

        :param reward: Reward received for the previous action
        :param done: Whether the episode has ended
        """

class StaticRNNPolicyAgent(Agent):
    """
    Static (non-learning) agent driven by a recurrent policy.

    The recurrent hidden state is carried across calls. Both
    :meth:`get_action` and :meth:`get_action_probs` run the policy forward
    one step and therefore advance that state; call only one of them per
    environment step.

    :param policy: Recurrent policy producing the agent's actions
    """

    def __init__(self, policy: RecurrentActorCriticPolicy):
        self.policy = policy
        self.reset_rnn_state()

    def reset_rnn_state(self):
        """
        Zero the recurrent state and mark the next step as an episode start.
        """
        device = self.policy.device
        self.last_episode_starts = th.tensor(True, dtype=th.float32, device=device)

        lstm = self.policy.lstm_actor
        # Single-environment state; the actor LSTM's sizes are reused for the
        # critic state as well.
        shape = (lstm.num_layers, 1, lstm.hidden_size)

        def zero_pair():
            return th.zeros(shape, device=device), th.zeros(shape, device=device)

        self.rnn_state = RNNStates(zero_pair(), zero_pair())

    def _forward_step(self, obs: Observation, done):
        """
        Run the policy once on ``obs``, keep the new hidden state, and store
        ``done`` as the episode-start flag used by the following step.

        :returns: ``(actions, log_probs)`` from ``action_from_rnn_policy``
        """
        outputs = action_from_rnn_policy(
            obs.obs, self.rnn_state, self.last_episode_starts, self.policy)
        actions, _, log_probs, self.rnn_state = outputs
        self.last_episode_starts = th.tensor(
            done, dtype=th.float32, device=self.policy.device)
        return actions, log_probs

    def get_action(self, obs: Observation, done: th.Tensor, record: bool = True) -> np.ndarray:
        """
        Step the recurrent policy and return the chosen action.

        :param obs: The observation to use
        :param done: Whether this step ends the episode; used to reset the
            hidden state on the next call
        :param record: Unused (the agent does not learn)
        :returns: The action, post-processed by ``clip_actions``
        """
        actions, _ = self._forward_step(obs, done)
        return clip_actions(actions, self.policy)[0]

    def get_action_probs(self, obs: Observation, done: th.Tensor, record: bool = True) -> np.ndarray:
        """
        Step the recurrent policy and return ``exp`` of its log-probabilities.

        :param obs: The observation to use
        :param done: Whether this step ends the episode
        :param record: Unused (the agent does not learn)
        :returns: ``exp(log_probs)`` as a numpy array
        """
        _, log_probs = self._forward_step(obs, done)
        return np.exp(log_probs.cpu().numpy())

    def update(self, reward: float, done: bool) -> None:
        """
        No-op: a static agent never learns from feedback.
        """
        pass

class StaticPolicyAgent(Agent):
    """
    Agent representing a static (not learning) policy.

    :param policy: Policy representing the agent's responses to observations
    """

    def __init__(self, policy: ActorCriticPolicy):
        self.policy = policy

    def get_action(self, obs: Observation, record: bool = True, instruction: np.ndarray = None) -> np.ndarray:
        """
        Return an action given an observation.

        :param obs: The observation to use
        :param record: Whether to record the obs, action (unused)
        :param instruction: Optional extra input forwarded to
            ``action_from_policy`` together with the observation
        :returns: The action to take, post-processed by ``clip_actions``
        """
        actions, _, _ = action_from_policy(obs.obs, self.policy, instruction)
        return clip_actions(actions, self.policy)[0]

    def get_action_probs(self, obs: Observation, record: bool = True,
                         instruction: np.ndarray = None) -> np.ndarray:
        """
        Run the policy on ``obs`` and return ``exp`` of the log-probabilities
        produced by ``action_from_policy``.

        :param obs: The observation to use
        :param record: Unused (the agent does not learn)
        :param instruction: Optional extra input, forwarded exactly as in
            :meth:`get_action` so the probabilities match the instructed
            policy. When None, ``action_from_policy`` is called with the
            observation and policy only, as before.
        :returns: ``exp(log_probs)`` as a numpy array
        """
        if instruction is None:
            _, _, log_probs = action_from_policy(obs.obs, self.policy)
        else:
            _, _, log_probs = action_from_policy(obs.obs, self.policy, instruction)
        return np.exp(log_probs.cpu().numpy())

    def update(self, reward: float, done: bool) -> None:
        """
        Update does nothing since the agent does not learn.
        """
        pass


class OnPolicyAgent(Agent):
    """
    Agent representing an on-policy learning algorithm (ex: A2C/PPO).

    The `get_action` and `update` functions are based on the `learn` function
    from ``OnPolicyAlgorithm``. Recorded steps are written to the model's
    rollout buffer. Once ``model.n_steps`` steps have been recorded, the next
    recorded `get_action` call trains the model and clears the buffer.

    :param model: Model representing the agent's learning algorithm
    :param log_interval: Training iterations between log dumps. When falsy,
        defaults to 1 if ``model.verbose`` is truthy, otherwise logging is
        disabled.
    :param tensorboard_log: Tensorboard log location passed to
        ``configure_logger``
    :param tb_log_name: Name of the logger run, also logged under "name"
    """

    def __init__(self,
                 model: OnPolicyAlgorithm,
                 log_interval=None,
                 tensorboard_log=None,
                 tb_log_name="OnPolicyAgent"):
        self.model = model
        # Done flag from the latest `update`. It is stored as the
        # episode-start flag of the next recorded step.
        self._last_episode_starts = [True]
        # Number of steps written to the rollout buffer since the last
        # training iteration.
        self.n_steps = 0
        # Value estimate returned with the most recent action.
        self.values: th.Tensor = th.empty(0)

        self.model.set_logger(configure_logger(
            self.model.verbose, tensorboard_log, tb_log_name))

        self.name = tb_log_name
        self.num_timesteps = 0
        self.log_interval = log_interval or (1 if model.verbose else None)
        self.iteration = 0
        # The last entry is the episode in progress. Its "r"/"l" are
        # accumulated in place, and a fresh entry is appended when done.
        self.model.ep_info_buffer = deque([{"r": 0, "l": 0}], maxlen=100)

    def get_action(self, obs: Observation, record: bool = True) -> np.ndarray:
        """
        Return an action given an observation.

        When `record` is True, the agent saves the transition into its
        rollout buffer. If ``model.n_steps`` steps are already buffered, it
        first trains the model on that rollout and clears the buffer.

        :param obs: The observation to use
        :param record: Whether to record the obs, action (True when training)
        :returns: The action to take
        """
        obs = obs.obs

        # train the model if the buffer is full
        if record and self.n_steps >= self.model.n_steps:
            self._train_on_rollout(obs)

        resample_noise(self.model, self.n_steps)

        actions, values, log_probs = action_from_policy(obs, self.model.policy)

        if record:
            self._record_step(obs, actions, values, log_probs)
            # Only buffered steps count towards the rollout length. Counting
            # unrecorded steps would trigger training on a partially filled
            # buffer.
            self.n_steps += 1

        self.num_timesteps += 1
        self.values = values
        return clip_actions(actions, self.model)[0]

    def _train_on_rollout(self, obs) -> None:
        """Compute returns, log, run one training iteration, and clear the buffer."""
        buf = self.model.rollout_buffer
        # Bootstrap from the value of the observation that FOLLOWS the last
        # buffered step, computed with the pre-update policy. This is what
        # OnPolicyAlgorithm.collect_rollouts does. `self.values` holds the
        # value of the last buffered state itself, so it would bias the final
        # advantage whenever a rollout ends mid-episode.
        _, last_values, _ = action_from_policy(obs, self.model.policy)
        buf.compute_returns_and_advantage(
            last_values=last_values,
            dones=self._last_episode_starts[0]
        )

        if self.log_interval is not None and \
                self.iteration % self.log_interval == 0:
            self._log_iteration()

        self.model.train()
        self.iteration += 1
        buf.reset()
        self.n_steps = 0

    def _log_iteration(self) -> None:
        """Record and dump iteration and episode statistics to the model's logger."""
        logger = self.model.logger
        logger.record("name", self.name, exclude="tensorboard")
        logger.record("time/iterations", self.iteration, exclude="tensorboard")

        ep_info_buffer = self.model.ep_info_buffer
        if len(ep_info_buffer) > 0 and len(ep_info_buffer[0]) > 0:
            # The newest entry is the episode still in progress. Exclude it
            # from the means, then put it back.
            last_exclude = ep_info_buffer.pop()
            rews = [ep["r"] for ep in ep_info_buffer]
            lens = [ep["l"] for ep in ep_info_buffer]
            logger.record("rollout/ep_rew_mean", safe_mean(rews))
            logger.record("rollout/ep_len_mean", safe_mean(lens))
            ep_info_buffer.append(last_exclude)

        logger.record(
            "time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        logger.dump(step=self.num_timesteps)

    def _record_step(self, obs, actions, values, log_probs) -> None:
        """Count the step in the current episode and append it to the rollout buffer."""
        self.model.ep_info_buffer[-1]["l"] += 1

        obs_shape = self.model.policy.observation_space.shape
        act_shape = self.model.policy.action_space.shape
        # The reward starts at 0 and is accumulated later by `update`.
        self.model.rollout_buffer.add(
            np.reshape(obs, (1,) + obs_shape),
            np.reshape(actions, (1,) + act_shape),
            [0],
            self._last_episode_starts,
            values,
            log_probs
        )

    def update(self, reward: float, done: bool) -> None:
        """
        Add new rewards and done information.

        The rewards are added to buffer entry corresponding to the most recent
        recorded action.

        :param reward: The reward received from the previous action step
        :param done: Whether the game is done
        """
        buf = self.model.rollout_buffer
        self._last_episode_starts = [done]
        buf.rewards[buf.pos - 1][0] += reward
        self.model.ep_info_buffer[-1]["r"] += reward
        if done:
            self.model.ep_info_buffer.append({"r": 0, "l": 0})

    def learn(self, **kwargs) -> None:
        """ Call the model's learn function with the given parameters """
        self.model._custom_logger = False
        self.model.learn(**kwargs)


class OffPolicyAgent(Agent):
    """
    Agent representing an off-policy learning algorithm (ex: DQN/SAC).

    The `get_action` and `update` functions are based on the `learn` function
    from ``OffPolicyAlgorithm``. Transitions are written to the replay buffer
    lazily. The action chosen in one recorded `get_action` call is stored,
    together with the rewards accumulated by `update`, when the next recorded
    `get_action` call supplies the following observation.

    :param model: Model representing the agent's learning algorithm
    :param log_interval: Episodes between log dumps. When falsy, defaults to
        4 if ``model.verbose`` is truthy, otherwise logging is disabled.
    :param tensorboard_log: Tensorboard log location passed to
        ``configure_logger``
    :param tb_log_name: Name of the logger run, also logged under "name"
    """

    def __init__(self,
                 model: OffPolicyAlgorithm,
                 log_interval=None,
                 tensorboard_log=None,
                 tb_log_name="OffPolicyAgent"):
        self.model = model
        self.model.start_time = time.time()
        # Per-rollout episode statistics, cleared after each training round.
        self.episode_rewards: List[float] = []
        self.total_timesteps: List[int] = []
        # Progress towards model.train_freq since the last training round.
        self.num_collected_steps = 0
        self.num_collected_episodes = 0
        # Reward/done/info of the pending transition, whose next observation
        # has not been seen yet.
        self.old_reward: float = 0.0
        self.old_done = False
        self.old_info: Dict = {}

        self.episode_reward: float = 0.0
        self.episode_timesteps = 0
        self.n_steps = 0
        # Buffer action of the pending transition (None before the first action).
        self.old_buffer_action = None

        self.log_interval = log_interval or (4 if model.verbose else None)
        self.name = tb_log_name
        self.model.set_logger(configure_logger(
            self.model.verbose, tensorboard_log, tb_log_name))
        # The last entry is the episode in progress; a fresh one is appended
        # when an episode ends.
        self.model.ep_info_buffer = deque([{"r": 0, "l": 0}], maxlen=100)

    def get_action(self, obs: Observation, record: bool = True) -> np.ndarray:
        """
        Return an action given an observation.

        When `record` is True, the agent saves the last transition into its
        buffer.

        :param obs: The observation to use
        :param record: Whether to record the obs, action (True when training)
        :returns: The action to take
        """
        obs = obs.obs
        if record:
            if self.old_buffer_action is not None:
                # Complete the pending transition now that its next
                # observation is known.
                buf = self.model.replay_buffer
                buf.observations[buf.pos] = np.array(obs).copy()
                self.model._store_transition(
                    buf, self.old_buffer_action, obs, self.old_reward,
                    self.old_done, [self.old_info])

            if self.old_done:
                self._finish_episode()

        resample_noise(self.model, self.n_steps)

        obs = obs.reshape((-1,) + self.model.policy.observation_space.shape)
        self.model._last_obs = obs

        action, buffer_action = self.model._sample_action(
            self.model.learning_starts, self.model.action_noise)

        self.model.num_timesteps += 1
        self.episode_timesteps += 1
        self.num_collected_steps += 1
        self.n_steps += 1

        self.old_buffer_action = buffer_action
        self.old_reward = 0

        return clip_actions(action, self.model)[0]

    def _finish_episode(self) -> None:
        """Episode-end bookkeeping: counters, action-noise reset, and logging."""
        self.num_collected_episodes += 1
        self.model._episode_num += 1
        self.episode_rewards.append(self.episode_reward)
        self.total_timesteps.append(self.episode_timesteps)

        if self.model.action_noise is not None:
            self.model.action_noise.reset()
        self.episode_reward = 0.0
        self.episode_timesteps = 0

        if self.log_interval is not None and \
                self.model._episode_num % self.log_interval == 0:
            self._log_progress()

    def _log_progress(self) -> None:
        """Record and dump episode statistics to the model's logger."""
        logger = self.model.logger
        logger.record("name", self.name, exclude="tensorboard")
        logger.record(
            "time/episodes", self.model._episode_num, exclude="tensorboard")

        ep_info_buffer = self.model.ep_info_buffer
        if len(ep_info_buffer) > 0 and len(ep_info_buffer[0]) > 0:
            # The newest entry is the episode still in progress. Exclude it
            # from the means, then put it back.
            last_exclude = ep_info_buffer.pop()
            rews = [ep["r"] for ep in ep_info_buffer]
            lens = [ep["l"] for ep in ep_info_buffer]
            logger.record("rollout/ep_rew_mean", safe_mean(rews))
            logger.record("rollout/ep_len_mean", safe_mean(lens))
            ep_info_buffer.append(last_exclude)

        logger.record(
            "time/total_timesteps", self.model.num_timesteps,
            exclude="tensorboard")
        logger.dump(step=self.model.num_timesteps)

    def update(self, reward: float, done: bool) -> None:
        """
        Add new rewards and done information.

        The agent trains when the model determines that it has collected
        enough timesteps (``train_freq``). As in ``OffPolicyAlgorithm.learn``,
        it only trains once past ``learning_starts`` and when the resolved
        number of gradient steps is positive. It also waits until the replay
        buffer holds at least one transition.

        :param reward: The reward received from the previous action step
        :param done: Whether the game is done
        """
        self.episode_reward += reward

        self.old_done = done
        self.old_reward += reward

        lastinfo = self.model.ep_info_buffer[-1]
        lastinfo["r"] += reward
        if done:
            self.model.ep_info_buffer.append({"r": 0, "l": 0})
        else:
            # NOTE(review): the terminal step is not counted in "l", unlike
            # OnPolicyAgent. Confirm whether this off-by-one is intended.
            lastinfo["l"] += 1

        if should_collect_more_steps(self.model.train_freq,
                                     self.num_collected_steps,
                                     self.num_collected_episodes):
            return

        # Same semantics as OffPolicyAlgorithm.learn. A negative value (-1)
        # means one gradient step per collected env step; 0 means no training.
        gradient_steps = self.model.gradient_steps
        if gradient_steps < 0:
            gradient_steps = self.num_collected_steps

        # Transitions are stored lazily (on the next recorded get_action), so
        # the replay buffer may still be empty here; sampling would fail.
        # NOTE(review): OffPolicyAlgorithm.collect_rollouts also calls
        # model._on_step() after every env step (DQN target/exploration
        # updates); this agent does not. Confirm whether that is intended.
        if gradient_steps > 0 and \
                self.model.num_timesteps > self.model.learning_starts and \
                self.model.replay_buffer.size() > 0:
            self.model.train(batch_size=self.model.batch_size,
                             gradient_steps=gradient_steps)

        self.episode_rewards = []
        self.total_timesteps = []
        self.num_collected_steps = 0
        self.num_collected_episodes = 0

    def learn(self, **kwargs) -> None:
        """ Call the model's learn function with the given parameters """
        self.model._custom_logger = False
        self.model.learn(**kwargs)


class RecordingAgentWrapper(Agent):
    """
    Wrapper for an agent that records observation-action pairs.

    Users can also use SimultaneousRecorder or TurnBasedRecorder (from
    wrappers.py) to record the transitions in an environment.

    :param realagent: Agent that defines the behaviour of this actor
    """

    def __init__(self, realagent: Agent):
        self.realagent = realagent
        # Recorded raw observations (``obs.obs``) and the matching actions,
        # in call order.
        self.allobs: List[np.ndarray] = []
        self.allacts: List[np.ndarray] = []

    def get_action(self, obs: Observation, record: bool = True) -> np.ndarray:
        """
        Return an action given an observation.

        The output is the same as calling `get_action` on the realagent, but
        this wrapper also stores the observation-action pair to a buffer.
        Pairs are stored regardless of `record`.

        :param obs: The observation to use
        :param record: Whether to record the obs, action (True when training)
        :returns: The action to take
        """
        # Pass `record` by keyword, as named in Agent.get_action. Passed
        # positionally, it silently lands in the `done` slot of
        # StaticRNNPolicyAgent.get_action(obs, done, record).
        action = self.realagent.get_action(obs, record=record)
        self.allobs.append(obs.obs)
        self.allacts.append(action)
        return action

    def update(self, reward: float, done: bool) -> None:
        """
        Simply calls the realagent's update function

        :param reward: The reward received from the previous action step
        :param done: Whether the game is done
        """
        self.realagent.update(reward, done)

    def get_transitions(self) -> TransitionsMinimal:
        """
        Return the transitions recorded by this agent.

        :returns: A TransitionsMinimal object built from the recorded
            observations and actions, each stacked with ``np.array``
        """
        obs = np.array(self.allobs)
        acts = np.array(self.allacts)
        return TransitionsMinimal(obs, acts)
