import torch
from copy import deepcopy
from k_level_policy_gradients.src.algorithms.agent import Agent
from k_level_policy_gradients.src.approximators.torch_approximator import (
    TorchApproximator,
)


class DDPG(Agent):
    """
    Deep Deterministic Policy Gradient algorithm.
    "Continuous Control with Deep Reinforcement Learning".
    Lillicrap T. P. et al.. 2016.

    Multi-agent variant: each instance learns from the entries of the joint
    transitions that belong to ``idx_agent``. When a ``primary_agent`` is
    given, every agent with ``idx_agent != 0`` links its actor, critic and
    both target approximators to the primary agent's ones through
    ``set_primary_approximator``; only agent 0 (or an agent without a
    primary) initialises and soft-updates the target networks.

    """

    def __init__(
        self,
        mdp_info,
        idx_agent,
        policy,
        actor_params,
        critic_params,
        centralized_critic,
        batch_size,
        target_update_frequency,
        tau,
        warmup_replay_size,
        replay_memory,
        grad_norm_clip,
        use_cuda,
        primary_agent,
        use_mixer,
    ):
        """
        Constructor.

        Args:
            mdp_info: environment information; ``mdp_info.gamma`` is used as
                the discount factor in the critic target.
            idx_agent (int): index of this agent in the per-agent entries
                (``obs``, ``actions``, ``rewards``, ``next_obs``) of the
                joint dataset.
            policy: policy object; it is given the actor approximator via
                ``set_approximator`` and is used by ``draw_action``.
            actor_params (dict): keyword arguments for the actor
                ``TorchApproximator``; a deep copy builds the target actor.
            critic_params (dict): keyword arguments for the critic
                ``TorchApproximator``; a deep copy builds the target critic.
            centralized_critic: stored only; not read by this class.
            batch_size (int): number of transitions sampled per update.
            target_update_frequency: stored and saved, but not read by this
                class (targets are soft-updated on every ``fit`` call).
            tau (float): soft target-update coefficient.
            warmup_replay_size (int): gradient steps start only once the
                replay memory holds strictly more samples than this.
            replay_memory: buffer exposing ``add``, ``get`` and ``size``.
            grad_norm_clip (float or None): max gradient norm for both actor
                and critic; ``None`` disables clipping.
            use_cuda (bool): stored and saved; note that ``_fit`` builds its
                tensors on the default (CPU) device.
            primary_agent (DDPG or None): agent whose approximators are shared
                by agents with ``idx_agent != 0``.
            use_mixer (bool): when True, ``fit`` neither stores data nor runs
                gradient steps (handled by a mixer elsewhere).

        """
        super().__init__(mdp_info, policy, idx_agent)

        self._centralized_critic = centralized_critic
        self._batch_size = batch_size
        self._target_update_frequency = target_update_frequency
        self._tau = tau
        self._warmup_replay_size = warmup_replay_size
        self._grad_norm_clip = grad_norm_clip
        self._use_mixer = use_mixer
        self._use_cuda = use_cuda

        self._replay_memory = replay_memory

        self._n_updates = 0

        self._primary_agent = primary_agent

        target_actor_params = deepcopy(actor_params)
        self.actor_approximator = TorchApproximator(**actor_params)
        self.target_actor_approximator = TorchApproximator(**target_actor_params)
        target_critic_params = deepcopy(critic_params)
        self.critic_approximator = TorchApproximator(**critic_params)
        self.target_critic_approximator = TorchApproximator(**target_critic_params)
        if primary_agent is None or idx_agent == 0:
            self._update_targets_hard()
        else:
            # Set this agent's actor, target, critic, critic target to be the same as the other agent
            self.actor_approximator.set_primary_approximator(
                primary_agent.actor_approximator
            )
            self.target_actor_approximator.set_primary_approximator(
                primary_agent.target_actor_approximator
            )
            self.critic_approximator.set_primary_approximator(
                primary_agent.critic_approximator
            )
            self.target_critic_approximator.set_primary_approximator(
                primary_agent.target_critic_approximator
            )
        self.policy.set_approximator(self.actor_approximator)
        self._optimizer = self.actor_approximator._optimizer
        self._critic_optimizer = self.critic_approximator._optimizer

        # Materialise the parameter collections: they are iterated on every
        # update (gradient clipping) and by the diagnostic helpers below. If
        # ``parameters()`` returns a one-shot iterator (as nn.Module does),
        # storing it directly would leave it empty after the first use.
        self.actor_params = list(self.actor_approximator.parameters())
        self.critic_params = list(self.critic_approximator.parameters())

        # NOTE(review): _critic_optimizer, _grad_norm_clip, _centralized_critic,
        # _primary_agent, actor_params and critic_params are not registered
        # below; confirm how they are restored after loading a saved agent.
        self._add_save_attr(
            _batch_size="primitive",
            _target_update_frequency="primitive",
            _tau="primitive",
            _warmup_replay_size="primitive",
            _replay_memory="mushroom!",
            _n_updates="primitive",
            actor_approximator="mushroom",
            target_actor_approximator="mushroom",
            critic_approximator="mushroom",
            target_critic_approximator="mushroom",
            _optimizer="torch",
            _use_mixer="primitive",
            _use_cuda="primitive",
        )

    def draw_action(self, state, action_mask=None):
        """
        Args:
            state (np.ndarray): the state where the agent is.
            action_mask: accepted for interface compatibility; ignored.

        Returns:
            The action to be executed, as returned by ``self.policy``.

        """
        return self.policy.draw_action(state)

    def fit(self, dataset):
        """
        Store this agent's share of ``dataset`` and run one update.

        Args:
            dataset: iterable of joint transition dicts (see
                ``split_dataset`` for the keys read).

        Returns:
            Tuple ``(actor_loss, critic_loss)``; ``(0, 0)`` when the mixer
            handles fitting or the replay memory is still warming up.

        """
        if self._use_mixer:
            actor_loss, critic_loss = 0, 0  # storage and fitting handled by mixer
        else:
            own_dataset = self.split_dataset(dataset)
            self._replay_memory.add(own_dataset)
            actor_loss, critic_loss = self._fit()

        self._n_updates += 1
        # Only the agent owning the (possibly shared) networks moves the
        # targets, so shared targets are not updated once per agent.
        if self._idx_agent == 0 or self._primary_agent is None:
            self._update_targets_soft()

        return actor_loss, critic_loss

    def split_dataset(self, dataset):
        """
        Extract this agent's view of each joint transition.

        Global entries (``state``, ``next_state``, ``absorbing``, ``last``)
        are copied as-is; per-agent entries are indexed by ``idx_agent``.

        Args:
            dataset: iterable of dicts with keys ``state``, ``obs``,
                ``actions``, ``rewards``, ``next_state``, ``next_obs``,
                ``absorbing`` and ``last``.

        Returns:
            list of dicts with keys ``state``, ``obs``, ``action``,
            ``reward``, ``next_state``, ``next_obs``, ``absorbing``, ``last``.

        """
        idx = self._idx_agent
        return [
            {
                "state": sample["state"],
                "obs": sample["obs"][idx],
                "action": sample["actions"][idx],
                "reward": sample["rewards"][idx],
                "next_state": sample["next_state"],
                "next_obs": sample["next_obs"][idx],
                "absorbing": sample["absorbing"],
                "last": sample["last"],
            }
            for sample in dataset
        ]

    def _fit(self):
        """
        Run one critic step followed by one actor step on a replay minibatch.

        Returns:
            ``(actor_loss, critic_loss)`` as Python floats, or ``(0, 0)``
            while the replay memory holds ``warmup_replay_size`` samples or
            fewer.

        """
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        # NOTE(review): assumes ``get`` returns fields in the order of the
        # dicts built by ``split_dataset`` (state, obs, action, reward,
        # next_state, next_obs, absorbing, last) -- confirm.
        _, obs, actions, rewards, _, next_obs, absorbing, _ = (
            self._replay_memory.get(self._batch_size)
        )

        obs_t = torch.tensor(obs, dtype=torch.float32)
        actions_t = torch.tensor(actions, dtype=torch.float32)
        rewards_t = torch.tensor(rewards, dtype=torch.float32)
        next_obs_t = torch.tensor(next_obs, dtype=torch.float32)
        absorbing_t = torch.tensor(absorbing, dtype=torch.bool)

        # Critic update: regress Q(obs, action) towards the one-step target.
        q_hat = self.critic_approximator.predict(
            obs_t, actions_t, output_tensor=True
        )
        # The target is a constant for the critic loss; computing it under
        # no_grad avoids recording a graph through both target networks.
        with torch.no_grad():
            q_next = self._next_q(next_obs_t)
            # ``~absorbing_t`` drops the bootstrap term for terminal transitions.
            q_target = rewards_t + self.mdp_info.gamma * q_next * ~absorbing_t
        critic_loss = self.critic_approximator._loss(q_hat, q_target)
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.critic_params, self._grad_norm_clip)
        self._critic_optimizer.step()

        # Actor update: maximise the critic's value of the actor's actions.
        actor_actions = self.actor_approximator.predict(obs_t, output_tensor=True)
        q = self.critic_approximator.predict(
            obs_t, actor_actions, output_tensor=True
        )
        actor_loss = -q.mean()
        self._optimizer.zero_grad()
        # NOTE(review): this backward pass also accumulates into the critic
        # parameters' .grad (presumably not stepped, assuming _optimizer only
        # covers actor parameters), so critic_grad_norm() read after _fit
        # mixes both losses.
        actor_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.actor_params, self._grad_norm_clip)
        self._optimizer.step()

        return actor_loss.item(), critic_loss.item()

    def _draw_target_action(self, next_state_t):
        """
        Draw an action from the target actor without noise.

        Args:
            next_state_t (torch.Tensor): the state where the action is drawn.

        Returns:
            next_mu (torch.Tensor): the greedy action drawn from the target actor network.
        """
        mu_target = self.target_actor_approximator.predict(
            next_state_t, output_tensor=True
        )

        return mu_target

    def _next_q(self, next_state_t):
        """
        Args:
            next_state_t (torch.Tensor): the states where the next action has
                to be evaluated.

        Returns:
            Action-values returned by the target critic for ``next_state_t``
            and the action returned by the target actor.

        """
        a_next_t = self._draw_target_action(next_state_t)
        q_next = self.target_critic_approximator.predict(
            next_state_t, a_next_t, output_tensor=True
        )

        return q_next

    def _update_targets_soft(self):
        """
        Soft-update both target networks towards their online networks.

        """
        self._update_target_soft(
            self.actor_approximator, self.target_actor_approximator
        )
        self._update_target_soft(
            self.critic_approximator, self.target_critic_approximator
        )

    def _update_targets_hard(self):
        """
        Copy the online weights into both target networks.

        """
        self._update_target_hard(
            self.actor_approximator, self.target_actor_approximator
        )
        self._update_target_hard(
            self.critic_approximator, self.target_critic_approximator
        )

    def _update_target_soft(self, online, target):
        """Set ``target`` weights to ``tau * online + (1 - tau) * target``."""
        weights = self._tau * online.get_weights()
        weights += (1 - self._tau) * target.get_weights()
        target.set_weights(weights)

    def _update_target_hard(self, online, target):
        """Set ``target`` weights equal to ``online`` weights."""
        target.set_weights(online.get_weights())

    def actor_param(self):
        """Return all actor parameters concatenated into one flat tensor."""
        params = torch.cat([param.flatten() for param in self.actor_params])
        return params

    def actor_gradient(self):
        """
        Return the actor gradients concatenated into one flat tensor.

        Parameters whose ``grad`` is ``None`` are skipped; ``torch.cat``
        raises if no parameter has a gradient yet.
        """
        gradient = torch.cat(
            [
                param.grad.flatten()
                for param in self.actor_params
                if param.grad is not None
            ]
        )
        return gradient

    @staticmethod
    def _grad_norm(params):
        """
        Global L2 norm (Python float) of the gradients of ``params``.

        Parameters without a gradient (``grad is None``) are skipped, so the
        result is 0.0 before any backward pass instead of an AttributeError.
        """
        squared_sum = 0.0
        for param in params:
            if param.grad is not None:
                squared_sum += param.grad.detach().norm(2).item() ** 2
        return squared_sum**0.5

    def actor_grad_norm(self):
        """Global L2 norm of the actor gradients (see ``_grad_norm``)."""
        return self._grad_norm(self.actor_params)

    def critic_grad_norm(self):
        """Global L2 norm of the critic gradients (see ``_grad_norm``)."""
        return self._grad_norm(self.critic_params)
