from copy import deepcopy
import numpy as np
import torch
from k_level_policy_gradients.src.algorithms.agent import Agent
from k_level_policy_gradients.src.approximators.torch_approximator import (
    TorchApproximator,
)


class DiscreteDDPG(Agent):
    """
    Deep Deterministic Policy Gradient agent for discrete action spaces.

    Based on DDPG ("Continuous Control with Deep Reinforcement Learning",
    Lillicrap T. P. et al., 2016). The actor outputs one logit per action;
    the logits of masked-out actions are set to ``-inf`` and the result is
    turned into a one-hot-like action vector through ``policy._gumbel``
    before being passed to the critic.

    """

    def __init__(
        self,
        mdp_info,
        idx_agent,
        policy,
        actor_params,
        critic_params,
        batch_size,
        target_update_frequency,
        tau=None,
        warmup_replay_size=500,
        replay_memory=None,
        epsilon_train=None,
        use_cuda=False,
        primary_agent=None,
        use_mixer=False,
        obs_last_action=False,
    ):
        """
        Constructor.

        Args:
            mdp_info: information about the MDP, forwarded to the base
                ``Agent``; this class reads ``gamma`` and
                ``action_space[idx_agent].n`` from it;
            idx_agent (int): index of this agent, used to pick its own
                entries out of the joint dataset and action space;
            policy: the policy driven by the actor; it must provide
                ``set_approximator``, ``draw_action``, ``_gumbel`` and
                ``_tau``;
            actor_params (dict): keyword arguments used to build the actor
                ``TorchApproximator``; a deep copy builds the target actor;
            critic_params (dict): keyword arguments used to build the critic
                ``TorchApproximator``; a deep copy builds the target critic;
            batch_size (int): number of samples requested from the replay
                memory at each update;
            target_update_frequency (int): number of calls to ``fit``
                between two updates of the target networks;
            tau (None): stored as ``_tau``; not read by this class (the
                target networks are updated with a full weight copy);
            warmup_replay_size (int, 500): size the replay memory has to
                exceed before the networks start being updated;
            replay_memory (None): replay memory providing ``add``, ``get``
                and ``size``; required when ``use_mixer`` is False;
            epsilon_train (None): stored as ``_epsilon_train``; not read by
                this class;
            use_cuda (bool, False): stored as ``_use_cuda``; not read by
                this class;
            primary_agent (None): if given, the actor, critic and both
                target approximators are bound to the ones of this agent
                through ``set_primary_approximator``;
            use_mixer (bool, False): if True, ``fit`` neither stores samples
                nor updates the networks (both are left to the mixer);
            obs_last_action (bool, False): stored as ``_obs_last_action``;
                not read by this class.

        """
        super().__init__(mdp_info, policy, idx_agent)

        self._batch_size = batch_size
        self._target_update_frequency = target_update_frequency
        self._tau = tau
        self._warmup_replay_size = warmup_replay_size
        self._epsilon_train = epsilon_train
        self._use_mixer = use_mixer
        self._use_cuda = use_cuda
        self._obs_last_action = obs_last_action

        self._replay_memory = replay_memory

        self._n_updates = 0

        self._primary_agent = primary_agent

        # The parameter dicts are deep-copied so that the online and the
        # target approximators never share the objects they are built from.
        target_actor_params = deepcopy(actor_params)
        self.actor_approximator = TorchApproximator(**actor_params)
        self.target_actor_approximator = TorchApproximator(**target_actor_params)
        target_critic_params = deepcopy(critic_params)
        self.critic_approximator = TorchApproximator(**critic_params)
        self.target_critic_approximator = TorchApproximator(**target_critic_params)
        if primary_agent is None:
            # Start with the targets equal to the online networks.
            self._update_targets()
        else:
            # Bind this agent's actor, critic and their targets to the ones
            # of the primary agent.
            self.actor_approximator.set_primary_approximator(
                primary_agent.actor_approximator
            )
            self.target_actor_approximator.set_primary_approximator(
                primary_agent.target_actor_approximator
            )
            self.critic_approximator.set_primary_approximator(
                primary_agent.critic_approximator
            )
            self.target_critic_approximator.set_primary_approximator(
                primary_agent.target_critic_approximator
            )
        self.policy.set_approximator(self.actor_approximator)
        # The actor is trained with the optimizer owned by its approximator.
        self._optimizer = self.actor_approximator._optimizer

        self._add_save_attr(
            _batch_size="primitive",
            _target_update_frequency="primitive",
            _tau="mushroom",
            _warmup_replay_size="primitive",
            _epsilon_train="mushroom",
            _replay_memory="mushroom!",
            _n_updates="primitive",
            actor_approximator="mushroom",
            target_actor_approximator="mushroom",
            critic_approximator="mushroom",
            target_critic_approximator="mushroom",
            _optimizer="torch",
            _use_cuda="primitive",
            _use_mixer="primitive",
            _obs_last_action="primitive",
        )

    def draw_action(self, state, action_mask=None):
        """
        Return the action to execute in the given state, as drawn by the
        policy.

        Args:
            state (np.ndarray): the state where the agent is;
            action_mask (np.ndarray, None): the mask to apply to the action
                space.

        Returns:
            The action to be executed.

        """
        return self.policy.draw_action(state, action_mask)

    def fit(self, dataset):
        """
        Store the agent's own samples, update the networks and periodically
        refresh the target networks.

        Args:
            dataset (list): joint samples, one dictionary per step, in the
                format expected by ``split_dataset``.

        Returns:
            The actor loss and the critic loss of this step; both are 0 when
            no update was performed (mixer in use, or replay memory still
            warming up).

        """
        if self._use_mixer:
            # Storage and fitting are handled by the mixer.
            actor_loss, critic_loss = 0, 0
        else:
            own_dataset = self.split_dataset(dataset)
            self._replay_memory.add(own_dataset)
            actor_loss, critic_loss = self._fit()

        # ``_n_updates`` is incremented only here, once per call, so that
        # ``target_update_frequency`` has the same meaning with and without
        # the mixer.
        self._n_updates += 1
        if self._n_updates % self._target_update_frequency == 0:
            # Agents attached to a primary agent skip the copy unless they
            # are agent 0 (presumably because the networks are shared and
            # need to be refreshed only once).
            if self._idx_agent == 0 or self._primary_agent is None:
                self._update_targets()
        return actor_loss, critic_loss

    def split_dataset(self, dataset):
        """
        Extract this agent's view from a dataset of joint samples.

        Args:
            dataset (list): samples with the per-agent entries ``obs``,
                ``actions``, ``action_masks``, ``rewards``, ``next_obs`` and
                ``next_action_masks`` (indexed by agent), plus the shared
                ``absorbing`` and ``last`` flags.

        Returns:
            A list with one dictionary per sample, holding only the entries
            of agent ``idx_agent`` and the shared flags.

        """
        idx = self._idx_agent
        return [
            {
                "obs": sample["obs"][idx],
                "action": sample["actions"][idx],
                "action_mask": sample["action_masks"][idx],
                "reward": sample["rewards"][idx],
                "next_obs": sample["next_obs"][idx],
                "next_action_mask": sample["next_action_masks"][idx],
                "absorbing": sample["absorbing"],
                "last": sample["last"],
            }
            for sample in dataset
        ]

    def _fit(self):
        """
        Run one critic update followed by one actor update on a batch drawn
        from the replay memory.

        Returns:
            The actor loss and the critic loss as floats, or ``0, 0`` while
            the replay memory does not exceed ``warmup_replay_size``.

        """
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        (
            obs,
            actions,
            action_masks,
            rewards,
            next_obs,
            next_action_masks,
            absorbing,
            _,
        ) = self._replay_memory.get(self._batch_size)

        obs_t = torch.tensor(obs, dtype=torch.float32)
        # Indexing the identity matrix with the action indices yields their
        # one-hot encoding; axis 1 is squeezed out afterwards.
        # NOTE(review): assumes ``actions`` has shape (batch, 1) and an
        # integer dtype - confirm against the replay memory.
        actions_one_hot = np.eye(self.mdp_info.action_space[self._idx_agent].n)[
            actions
        ].squeeze(1)
        # NOTE(review): the stored actions reach the critic as ``long`` while
        # ``_loss`` feeds it the output of ``policy._gumbel`` - confirm the
        # critic network accepts both.
        actions_one_hot_t = torch.tensor(actions_one_hot, dtype=torch.long)
        action_masks_t = torch.tensor(action_masks, dtype=torch.bool)
        rewards_t = torch.tensor(rewards, dtype=torch.float32)
        next_obs_t = torch.tensor(next_obs, dtype=torch.float32)
        next_action_masks_t = torch.tensor(next_action_masks, dtype=torch.bool)
        absorbings_t = torch.tensor(absorbing, dtype=torch.bool)

        # Critic update
        # ``output_tensor=True`` as in every other ``predict`` call of this
        # class: ``q_hat`` is back-propagated through below.
        q_hat = self.critic_approximator.predict(
            obs_t, actions_one_hot_t, output_tensor=True
        )
        # The TD target is a constant of the critic loss: building it without
        # autograd avoids back-propagating into the target networks.
        with torch.no_grad():
            q_next = self._next_q(next_obs_t, next_action_masks_t)
            # No bootstrapping from absorbing states.
            q_next *= ~absorbings_t
            q_target = rewards_t + self.mdp_info.gamma * q_next
        critic_loss = self.critic_approximator._loss(q_hat, q_target)
        self.critic_approximator._optimizer.zero_grad()
        critic_loss.backward()
        self.critic_approximator._optimizer.step()

        # Actor update
        actor_loss = self._loss(obs_t, action_masks_t)
        self._optimizer.zero_grad()
        actor_loss.backward()
        self._optimizer.step()

        return actor_loss.item(), critic_loss.item()

    def _loss(self, state, action_mask):
        """
        Compute the actor loss, i.e. the negated mean critic value of the
        actions proposed by the actor.

        Args:
            state (torch.Tensor): batch of observations;
            action_mask (torch.Tensor): boolean mask, True for the actions
                that keep their logit.

        Returns:
            The scalar loss tensor.

        """
        logits = self.actor_approximator.predict(state, output_tensor=True)
        # Masked-out actions get a logit of -inf.
        logits_mask = torch.where(
            action_mask, logits.squeeze(), torch.tensor(float("-inf"))
        )
        # ``hard=False``: presumably a relaxed (soft) sample, so that the
        # critic value can be differentiated w.r.t. the actor - confirm
        # against ``policy._gumbel``.
        action_one_hot = self.policy._gumbel(
            logits_mask, tau=self.policy._tau, hard=False
        )
        q = self.critic_approximator.predict(state, action_one_hot, output_tensor=True)

        return -q.mean()

    def _draw_target_action(self, next_state, next_action_mask):
        """
        Draw the next action from the target actor.

        Args:
            next_state (torch.Tensor): the states where the action is drawn;
            next_action_mask (torch.Tensor): boolean mask, True for the
                actions that keep their logit.

        Returns:
            The output of ``policy._gumbel`` called with ``hard=True`` on the
            masked logits of the target actor.

        """
        logits = self.target_actor_approximator.predict(next_state, output_tensor=True)
        # Masked-out actions get a logit of -inf.
        logits_mask = torch.where(
            next_action_mask,
            logits.squeeze(),
            torch.tensor(float("-inf")),
        )
        action_one_hot = self.policy._gumbel(
            logits_mask, tau=self.policy._tau, hard=True
        )
        return action_one_hot

    def _next_q(self, next_state, next_action_mask):
        """
        Args:
            next_state (torch.Tensor): the states where the next action has
                to be evaluated;
            next_action_mask (torch.Tensor): boolean action mask for the
                states in ``next_state``.

        Returns:
            Action-values returned by the target critic for ``next_state``
            and the action returned by the target actor. Absorbing states
            are not handled here but by the caller.

        """
        action_one_hot = self._draw_target_action(next_state, next_action_mask)
        q = self.target_critic_approximator.predict(
            next_state, action_one_hot, output_tensor=True
        )

        return q

    def _update_targets(self):
        """
        Update the target networks by copying into them the weights of the
        online actor and critic.

        """
        self.target_actor_approximator.set_weights(
            self.actor_approximator.get_weights()
        )
        self.target_critic_approximator.set_weights(
            self.critic_approximator.get_weights()
        )

    def _post_load(self):
        """
        Re-link the actor to the approximator used by the policy after the
        agent has been loaded.

        """
        approximator = self.policy._approximator
        # The rest of the class reads ``actor_approximator`` (no leading
        # underscore): this is the attribute that has to follow the policy.
        self.actor_approximator = approximator
        # Alias kept for code that may still read the previous attribute name.
        self._actor_approximator = approximator
        # NOTE(review): ``_optimizer`` is restored on its own from the saved
        # agent; confirm it is still bound to the parameters of the re-linked
        # actor.
