import numpy as np

import torch
import torch.optim as optim
import torch.autograd.functional as func

from mushroom_rl.algorithms.actor_critic.deep_actor_critic import DeepAC
from mushroom_rl.policy import Policy
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.replay_memory import ReplayMemory
from mushroom_rl.utils.torch import to_float_tensor
from mushroom_rl.utils.parameters import to_parameter

from copy import deepcopy
from itertools import chain


class SACPolicy(Policy):
    """
    Class used to implement the policy used by the Soft Actor-Critic
    algorithm. The policy is a Gaussian policy squashed by a tanh and then
    affinely rescaled so that actions lie in ``(min_a, max_a)``.
    This class implements the compute_action_and_log_prob and the
    compute_action_and_log_prob_t methods, that are fundamental for
    the internal calculations of the SAC algorithm.

    """

    def __init__(
        self,
        mu_approximator,
        sigma_approximator,
        min_a,
        max_a,
        log_std_min,
        log_std_max,
    ):
        """
        Constructor.

        Args:
            mu_approximator (Regressor): a regressor computing the mean of the
                (pre-tanh) Gaussian given a state;
            sigma_approximator (Regressor): a regressor computing the log
                standard deviation of the (pre-tanh) Gaussian given a state;
            min_a (np.ndarray): a vector specifying the minimum action value
                for each component;
            max_a (np.ndarray): a vector specifying the maximum action value
                for each component;
            log_std_min ([float, Parameter]): min value for the policy log std;
            log_std_max ([float, Parameter]): max value for the policy log std.

        """
        self._mu_approximator = mu_approximator
        self._sigma_approximator = sigma_approximator

        # Half-range and center of the action box: a tanh output in (-1, 1)
        # is mapped to a * delta + central, i.e. into (min_a, max_a).
        self._delta_a = to_float_tensor(0.5 * (max_a - min_a), self.use_cuda)
        self._central_a = to_float_tensor(0.5 * (max_a + min_a), self.use_cuda)

        self._log_std_min = to_parameter(log_std_min)
        self._log_std_max = to_parameter(log_std_max)

        # Small constant keeping log(1 - tanh(x)^2) finite when |tanh(x)| -> 1.
        self._eps_log_prob = 1e-6

        if self.use_cuda:
            self._delta_a = self._delta_a.cuda()
            self._central_a = self._central_a.cuda()

        self._add_save_attr(
            _mu_approximator="mushroom",
            _sigma_approximator="mushroom",
            _delta_a="torch",
            _central_a="torch",
            _log_std_min="mushroom",
            _log_std_max="mushroom",
            _eps_log_prob="primitive",
        )

    def __call__(self, state, action):
        raise NotImplementedError

    def draw_action(self, state):
        """
        Sample an action (rescaled to the action bounds) for the given state.

        Args:
            state (np.ndarray): the state in which the action is sampled.

        Returns:
            The sampled action as a numpy array.

        """
        return (
            self.compute_action_and_log_prob_t(state, compute_log_prob=False)
            .detach()
            .cpu()
            .numpy()
        )

    def compute_action_and_log_prob(self, state):
        """
        Function that samples actions using the reparametrization trick and
        the log probability for such actions.

        Args:
            state (np.ndarray): the state in which the action is sampled.

        Returns:
            The actions sampled and the log probability as numpy arrays.

        """
        a, log_prob = self.compute_action_and_log_prob_t(state)
        return a.detach().cpu().numpy(), log_prob.detach().cpu().numpy()

    def compute_action_and_log_prob_t(self, state, compute_log_prob=True):
        """
        Function that samples actions using the reparametrization trick and,
        optionally, the log probability for such actions.

        The log probability includes the tanh change-of-variables correction
        but not the affine rescaling by the action half-range.

        Args:
            state (np.ndarray): the state in which the action is sampled;
            compute_log_prob (bool, True): whether to compute the log
                probability or not.

        Returns:
            The actions sampled and, optionally, the log probability as torch
            tensors.

        """
        dist = self.distribution(state)
        # rsample keeps the sample differentiable w.r.t. the network outputs.
        a_raw = dist.rsample()
        a = torch.tanh(a_raw)
        a_true = a * self._delta_a + self._central_a

        if compute_log_prob:
            log_prob = dist.log_prob(a_raw).sum(dim=1)
            # Jacobian correction of the tanh squashing.
            log_prob -= torch.log(1.0 - a.pow(2) + self._eps_log_prob).sum(dim=1)
            return a_true, log_prob
        else:
            return a_true

    def distribution(self, state):
        """
        Compute the (pre-tanh) Gaussian policy distribution in the given
        states.

        Args:
            state (np.ndarray): the set of states where the distribution is
                computed.

        Returns:
            The torch Normal distribution for the provided states, with the
            log standard deviation clamped to [log_std_min, log_std_max].

        """
        mu = self._mu_approximator.predict(state, output_tensor=True)
        log_sigma = self._sigma_approximator.predict(state, output_tensor=True)
        # Bound the log_std
        log_sigma = torch.clamp(log_sigma, self._log_std_min(), self._log_std_max())
        return torch.distributions.Normal(mu, log_sigma.exp())

    def entropy(self, state=None):
        """
        Compute the entropy of the (pre-tanh) Gaussian policy, averaged over
        all elements of the batch and action dimensions.

        Args:
            state (np.ndarray): the set of states to consider.

        Returns:
            The value of the entropy of the policy as a Python float.

        """

        return (
            torch.mean(self.distribution(state).entropy()).detach().cpu().numpy().item()
        )

    def distribution_state(self, state):
        """Alias of :meth:`distribution`."""
        return self.distribution(state)

    def policy_gradient(self, state, action):
        """
        Compute the Jacobian of the (pre-tanh) Gaussian density w.r.t. the
        action, evaluated at ``action``.

        NOTE(review): despite the name, this differentiates w.r.t. the action,
        not the policy parameters — confirm this is what callers expect.

        Args:
            state (np.ndarray): the states where the distribution is computed;
            action (array-like): the (pre-tanh) actions where the density is
                differentiated.

        Returns:
            The Jacobian as a torch tensor.

        """
        dist = self.distribution(state)
        # Place the action on the same device as the distribution parameters;
        # a plain torch.tensor() would always be on the CPU and break CUDA runs.
        action_t = torch.as_tensor(action, device=dist.loc.device)
        return func.jacobian(lambda x: dist.log_prob(x).exp(), action_t)

    def reset(self):
        pass

    def set_weights(self, weights):
        """
        Setter.

        Args:
            weights (np.ndarray): the vector of the new weights to be used by
                the policy; the first ``mu`` weights are followed by the
                ``sigma`` weights.

        """
        mu_weights = weights[: self._mu_approximator.weights_size]
        sigma_weights = weights[self._mu_approximator.weights_size :]

        self._mu_approximator.set_weights(mu_weights)
        self._sigma_approximator.set_weights(sigma_weights)

    def get_weights(self):
        """
        Getter.

        Returns:
             The current policy weights (mu weights followed by sigma
             weights).

        """
        mu_weights = self._mu_approximator.get_weights()
        sigma_weights = self._sigma_approximator.get_weights()

        return np.concatenate([mu_weights, sigma_weights])

    @property
    def use_cuda(self):
        """
        True if the policy is using cuda_tensors.
        """
        return self._mu_approximator.model.use_cuda

    def parameters(self):
        """
        Returns the trainable policy parameters, as expected by torch
        optimizers.

        Returns:
            Iterator over the parameters of the mu and sigma networks.

        """
        return chain(
            self._mu_approximator.model.network.parameters(),
            self._sigma_approximator.model.network.parameters(),
        )


class SAC(DeepAC):
    """
    Soft Actor-Critic algorithm.
    "Soft Actor-Critic Algorithms and Applications".
    Haarnoja T. et al.. 2019.

    This variant learns the policy of a single agent, selected by
    ``idx_agent``, from transitions whose actions and rewards are stored per
    agent.

    """

    def __init__(
        self,
        mdp_info,
        idx_agent,
        actor_mu_params,
        actor_sigma_params,
        actor_optimizer,
        critic_params,
        batch_size,
        initial_replay_size,
        max_replay_size,
        warmup_transitions,
        tau,
        lr_alpha,
        log_std_min=-20,
        log_std_max=2,
        target_entropy=None,
        critic_fit_params=None,
        log_alpha_init=-5.30,
    ):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information about the MDP; its
                ``action_space`` is indexed by ``idx_agent``;
            idx_agent (int): index of the agent (e.g. 0 or 1) whose action,
                reward and action space are used;
            actor_mu_params (dict): parameters of the actor mean approximator
                to build;
            actor_sigma_params (dict): parameters of the actor sigma
                approximator to build;
            actor_optimizer (dict): parameters to specify the actor
                optimizer algorithm;
            critic_params (dict): parameters of the critic approximator to
                build; ``n_models`` must be 2 if given. The dict is not
                modified;
            batch_size ((int, Parameter)): the number of samples in a batch;
            initial_replay_size (int): the number of samples to collect before
                starting the learning;
            max_replay_size (int): the maximum number of samples in the replay
                memory;
            warmup_transitions ([int, Parameter]): number of samples to
                accumulate in the replay memory to start the policy fitting;
            tau ([float, Parameter]): value of coefficient for soft updates;
            lr_alpha ([float, Parameter]): learning rate for the entropy
                coefficient;
            log_std_min ([float, Parameter]): min value for the policy log std;
            log_std_max ([float, Parameter]): max value for the policy log std;
            target_entropy (float, None): target entropy for the policy, if
                None minus the action-space dimensionality is used;
            critic_fit_params (dict, None): parameters of the fitting algorithm
                of the critic approximator;
            log_alpha_init (float, -5.30): initial value of the logarithm of
                the entropy coefficient (-5.30 gives alpha ~= 0.005).

        """
        self._critic_fit_params = (
            dict() if critic_fit_params is None else critic_fit_params
        )

        self._batch_size = to_parameter(batch_size)
        self._warmup_transitions = to_parameter(warmup_transitions)
        self._tau = to_parameter(tau)

        action_space = mdp_info.action_space[idx_agent]

        if target_entropy is None:
            self._target_entropy = -np.prod(action_space.shape).astype(np.float32)
        else:
            self._target_entropy = target_entropy

        self._replay_memory = ReplayMemory(initial_replay_size, max_replay_size)

        # Work on a shallow copy so that the caller's dict is not mutated.
        critic_params = dict(critic_params)
        if "n_models" in critic_params:
            assert critic_params["n_models"] == 2
        else:
            critic_params["n_models"] = 2

        target_critic_params = deepcopy(critic_params)
        self._critic_approximator = Regressor(TorchApproximator, **critic_params)
        self._target_critic_approximator = Regressor(
            TorchApproximator, **target_critic_params
        )

        actor_mu_approximator = Regressor(TorchApproximator, **actor_mu_params)
        actor_sigma_approximator = Regressor(TorchApproximator, **actor_sigma_params)

        policy = SACPolicy(
            actor_mu_approximator,
            actor_sigma_approximator,
            action_space.low,
            action_space.high,
            log_std_min,
            log_std_max,
        )

        self._init_target(self._critic_approximator, self._target_critic_approximator)

        # The log of the temperature is optimized so that alpha stays positive.
        self._log_alpha = torch.tensor(log_alpha_init, dtype=torch.float32)

        if policy.use_cuda:
            self._log_alpha = self._log_alpha.cuda().requires_grad_()
        else:
            self._log_alpha.requires_grad_()

        self._alpha_optim = optim.Adam([self._log_alpha], lr=lr_alpha)

        policy_parameters = policy.parameters()

        self._idx_agent = idx_agent

        self._fit_iteration = 1

        self._add_save_attr(
            _critic_fit_params="pickle",
            _batch_size="mushroom",
            _warmup_transitions="mushroom",
            _tau="mushroom",
            _target_entropy="primitive",
            _replay_memory="mushroom",
            _critic_approximator="mushroom",
            _target_critic_approximator="mushroom",
            _log_alpha="torch",
            _alpha_optim="torch",
            _fit_iteration="primitive",
            _idx_agent="primitive",
        )

        super().__init__(mdp_info, policy, actor_optimizer, policy_parameters)

        # Data collection (not registered for saving; see _post_load).
        self._init_data_collection()

    @property
    def idx_agent(self):
        """Index of the agent learned by this instance."""
        return self._idx_agent

    def split_dataset(self, dataset):
        """
        Extract this agent's view of a joint dataset.

        Args:
            dataset (list): samples with six fields; the second (actions) and
                third (rewards) fields are indexed by the agent index, the
                others are copied unchanged.

        Returns:
            A list of six-field tuples holding only this agent's action and
            reward.

        """
        idx = self._idx_agent
        return [
            (sample[0], sample[1][idx], sample[2][idx], sample[3], sample[4], sample[5])
            for sample in dataset
        ]

    def fit(self, dataset):
        """
        Store the new transitions and, once the replay memory is initialized,
        update the critic (and, after the warm-up phase, the actor and the
        temperature).

        Args:
            dataset (list): joint samples, see :meth:`split_dataset`.

        """
        own_dataset = self.split_dataset(dataset)
        self._replay_memory.add(own_dataset)

        if self._replay_memory.initialized:
            state, action, reward, next_state, absorbing, _ = self._replay_memory.get(
                self._batch_size()
            )

            # Actor update. The warm-up condition is evaluated once so that the
            # update and the statistics collection below always agree.
            actor_loss = None
            if self._replay_memory.size > self._warmup_transitions():
                action_new, log_prob = self.policy.compute_action_and_log_prob_t(state)
                actor_loss = self._loss(state, action_new, log_prob)
                self._optimize_actor_parameters(actor_loss)
                self._update_alpha(log_prob.detach())

            # Critic update
            q_next = self._next_q(next_state, absorbing)
            q = reward + self.mdp_info.gamma * q_next

            self._critic_approximator.fit(state, action, q, **self._critic_fit_params)

            self._update_target(
                self._critic_approximator, self._target_critic_approximator
            )

            self._record_fit_statistics(state, actor_loss)

        self._fit_iteration += 1

    def _init_data_collection(self):
        """Create empty lists for the per-fit training statistics."""
        self.temperature_data = []
        self.entropy_data = []
        self.actor_loss_data = []
        self.critic_loss_data = []

    def _record_fit_statistics(self, state, actor_loss):
        """
        Append the statistics of the current fit step.

        Args:
            state (np.ndarray): the batch of states used in this step;
            actor_loss (torch.Tensor, None): the actor loss, or None if the
                actor was not updated in this step.

        """
        if actor_loss is not None:
            self.temperature_data.append(self._alpha_np)
            self.entropy_data.append(self.policy.entropy(state))
            self.actor_loss_data.append(actor_loss.detach().cpu().numpy())

        self.critic_loss_data.append(self._get_critic_losses())

    def _get_critic_losses(self):
        """
        Collect the ``loss_fit`` value of every critic model that exposes it.

        Returns:
            A list with one (squeezed, when possible) loss per critic model.

        """
        critic_loss = list()
        for i in range(len(self._critic_approximator)):
            model_i = self._critic_approximator[i]
            if hasattr(model_i, "loss_fit"):
                m_loss = model_i.loss_fit
                if hasattr(m_loss, "squeeze"):
                    m_loss = m_loss.squeeze()
                critic_loss.append(m_loss)
        return critic_loss

    def _loss(self, state, action_new, log_prob):
        """
        Actor loss: mean of ``alpha * log_prob - min(Q_0, Q_1)``.
        """
        q_0 = self._critic_approximator(state, action_new, output_tensor=True, idx=0)
        q_1 = self._critic_approximator(state, action_new, output_tensor=True, idx=1)

        q = torch.min(q_0, q_1)

        return (self._alpha * log_prob - q).mean()

    def _update_alpha(self, log_prob):
        """
        One optimizer step on log(alpha) towards the target entropy.
        """
        alpha_loss = -(self._log_alpha * (log_prob + self._target_entropy)).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()

    def _next_q(self, next_state, absorbing):
        """
        Args:
            next_state (np.ndarray): the states where next action has to be
                evaluated;
            absorbing (np.ndarray): the absorbing flag for the states in
                ``next_state``.

        Returns:
            The soft value of ``next_state``: the minimum of the target
            critics at the action sampled by the actor, minus
            ``alpha * log_prob``; zero for absorbing states.

        """
        a, log_prob_next = self.policy.compute_action_and_log_prob(next_state)

        q = (
            self._target_critic_approximator.predict(next_state, a, prediction="min")
            - self._alpha_np * log_prob_next
        )
        q *= 1 - absorbing

        return q

    def _post_load(self):
        """
        Restore state that is not serialized after loading the agent.
        """
        self._update_optimizer_parameters(self.policy.parameters())

        # The statistics lists are not registered with _add_save_attr, so they
        # may be absent (or None) on a loaded object; recreate them. Older
        # checkpoints also lack _fit_iteration (it was registered under a
        # wrong name), so fall back to the constructor's initial value.
        data_attrs = (
            "temperature_data",
            "entropy_data",
            "actor_loss_data",
            "critic_loss_data",
        )
        for name in data_attrs:
            if getattr(self, name, None) is None:
                setattr(self, name, [])
        if getattr(self, "_fit_iteration", None) is None:
            self._fit_iteration = 1

    @property
    def _alpha(self):
        """The entropy coefficient alpha as a torch tensor."""
        return self._log_alpha.exp()

    @property
    def _alpha_np(self):
        """The entropy coefficient alpha as a numpy array."""
        return self._alpha.detach().cpu().numpy()

    def _log_info(self, loss, states):
        """
        Log actor/critic losses, alpha and the batch policy entropy, if a
        logger is set.
        """
        if self._logger:
            actor_loss = loss.detach().cpu().numpy()

            critic_loss = np.array(self._get_critic_losses()).squeeze()

            alpha = self._alpha_np
            batch_mean_entropy_per_dimension = self.policy.entropy(states)

            msg = (
                "Fit Iteration {}: \t actor loss: {} \t critic loss: {} \t updated alpha: {} \t "
                "batch mean entropy per action dimension: {} ".format(
                    self._fit_iteration,
                    actor_loss,
                    critic_loss,
                    alpha,
                    batch_mean_entropy_per_dimension,
                )
            )

            self._logger.info(msg)
