"""Proximal Policy Optimization (PPO)."""
import torch

from garage.torch.optimizers import OptimizerWrapper

from learning.utils import DnCOptimizerWrapper
from learning.algorithms.dnc_vpg import DnCVPG


class DnCPPO(DnCVPG):
    """Proximal Policy Optimization (PPO) on top of ``DnCVPG``.

    Reuses the ``DnCVPG`` training machinery and replaces the objective with
    the PPO clipped surrogate, evaluated per local policy (selected by
    ``policy_id``) against its stored old copy.

    Args:
        env_spec (EnvSpec): Environment specification.
        policy (garage.torch.policies.Policy): Policy forwarded unchanged to
            ``DnCVPG``.
        policies (list[garage.torch.policies.Policy]): Local policies. When
            ``policy_optimizers`` is None, one default optimizer is built per
            entry; :meth:`_compute_objective` indexes them by ``policy_id``.
        centroid (garage.torch.policies.Policy): Centroid policy forwarded to
            ``DnCVPG``. A default optimizer is only built for it when
            ``track_centroid`` is True.
        value_functions (list[garage.torch.value_functions.ValueFunction]):
            Value functions. When ``vf_optimizers`` is None, one default
            optimizer is built per entry.
        sampler (garage.sampler.Sampler): Sampler.
        policy_optimizers (list[DnCOptimizerWrapper]): One optimizer per
            entry of ``policies``. Defaults to Adam (lr=3e-4) with
            ``max_optimization_epochs=10`` and ``minibatch_size=32``.
        centroid_optimizer (garage.torch.optimizer.OptimizerWrapper):
            Optimizer for ``centroid``. If None and ``track_centroid`` is
            True, defaults to Adam (lr=3e-4) with
            ``max_optimization_epochs=1`` and ``minibatch_size=1000``;
            otherwise it is forwarded as given (possibly None).
        vf_optimizers (list[garage.torch.optimizer.OptimizerWrapper]): One
            optimizer per entry of ``value_functions``. Defaults to Adam
            (lr=3e-4) with ``max_optimization_epochs=10`` and
            ``minibatch_size=32``.
        lr_clip_range (float): The limit on the likelihood ratio between
            policies: the ratio is clamped to
            ``[1 - lr_clip_range, 1 + lr_clip_range]``.
        kl_coeff (float): Forwarded to ``DnCVPG``; presumably the weight of
            the KL term between local policies (confirm in ``DnCVPG``).
        num_train_per_epoch (int): Number of train_once calls per epoch.
        discount (float): Discount.
        gae_lambda (float): Lambda used for generalized advantage
            estimation.
        center_adv (bool): Whether to rescale the advantages
            so that they have mean 0 and standard deviation 1.
        positive_adv (bool): Whether to shift the advantages
            so that they are always positive. When used in
            conjunction with center_adv the advantages will be
            standardized before shifting.
        policy_ent_coeff (float): The coefficient of the policy entropy.
            Setting it to zero would mean no entropy regularization.
        use_softplus_entropy (bool): Whether to pass the entropy through a
            softplus so that it cannot be negative (handled by ``DnCVPG``).
        stop_entropy_gradient (bool): Whether to stop the entropy gradient.
        entropy_method (str): A string from: 'max', 'regularized',
            'no_entropy'. The type of entropy method to use. 'max' adds the
            dense entropy to the reward for each time step. 'regularized' adds
            the mean entropy to the surrogate objective. See
            https://arxiv.org/abs/1805.00909 for more details.
        track_centroid (bool): Forwarded to ``DnCVPG``; also decides whether
            a default ``centroid_optimizer`` is created here.

    """

    def __init__(
        self,
        env_spec,
        policy,
        policies,
        centroid,
        value_functions,
        sampler,
        policy_optimizers=None,
        centroid_optimizer=None,
        vf_optimizers=None,
        lr_clip_range=2e-1,
        kl_coeff=1.0,
        num_train_per_epoch=1,
        discount=0.99,
        gae_lambda=0.97,
        center_adv=True,
        positive_adv=False,
        policy_ent_coeff=0.0,
        use_softplus_entropy=False,
        stop_entropy_gradient=False,
        entropy_method="no_entropy",
        track_centroid=False,
    ):

        if policy_optimizers is None:
            # One optimizer per local policy. The loop variable is deliberately
            # not called ``policy`` so it cannot be mistaken for the ``policy``
            # argument forwarded to ``DnCVPG`` below.
            policy_optimizers = [
                DnCOptimizerWrapper(
                    (torch.optim.Adam, dict(lr=3e-4)),
                    local_policy,
                    max_optimization_epochs=10,
                    minibatch_size=32,
                )
                for local_policy in policies
            ]
        if track_centroid and centroid_optimizer is None:
            # Only build a centroid optimizer when the centroid is tracked;
            # otherwise ``centroid_optimizer`` is forwarded as given.
            centroid_optimizer = OptimizerWrapper(
                (torch.optim.Adam, dict(lr=3e-4)),
                centroid,
                max_optimization_epochs=1,
                minibatch_size=1000,
            )

        if vf_optimizers is None:
            # One optimizer per value function.
            vf_optimizers = [
                OptimizerWrapper(
                    (torch.optim.Adam, dict(lr=3e-4)),
                    local_vf,
                    max_optimization_epochs=10,
                    minibatch_size=32,
                )
                for local_vf in value_functions
            ]

        super().__init__(
            env_spec=env_spec,
            policy=policy,
            policies=policies,
            centroid=centroid,
            value_functions=value_functions,
            sampler=sampler,
            policy_optimizers=policy_optimizers,
            centroid_optimizer=centroid_optimizer,
            vf_optimizers=vf_optimizers,
            kl_coeff=kl_coeff,
            num_train_per_epoch=num_train_per_epoch,
            discount=discount,
            gae_lambda=gae_lambda,
            center_adv=center_adv,
            positive_adv=positive_adv,
            policy_ent_coeff=policy_ent_coeff,
            use_softplus_entropy=use_softplus_entropy,
            stop_entropy_gradient=stop_entropy_gradient,
            entropy_method=entropy_method,
            track_centroid=track_centroid,
        )

        self._lr_clip_range = lr_clip_range

    def _compute_objective(self, advantages, obs, actions, rewards, policy_id):
        r"""Compute the PPO clipped surrogate objective for one local policy.

        Args:
            advantages (torch.Tensor): Advantage value at each step
                with shape :math:`(N \cdot [T], )`.
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N \cdot [T], O*)`.
            actions (torch.Tensor): Actions fed to the environment
                with shape :math:`(N \cdot [T], A*)`.
            rewards (torch.Tensor): Acquired rewards
                with shape :math:`(N \cdot [T], )`. Not used by the clipped
                objective; presumably kept so the signature matches the
                caller in ``DnCVPG``.
            policy_id (int): Index into ``self.policies`` and
                ``self._old_policies`` selecting the local policy to score.

        Returns:
            torch.Tensor: Elementwise
                :math:`\min(r A, \mathrm{clip}(r, 1-\epsilon, 1+\epsilon) A)`,
                where :math:`r` is the new/old likelihood ratio and
                :math:`\epsilon` is ``lr_clip_range``, with shape
                :math:`(N \cdot [T], )`.

        """
        # Old-policy log-likelihoods act as constants: no gradient flows
        # through them. Calling a policy returns a tuple whose first element
        # is the action distribution.
        with torch.no_grad():
            old_ll = self._old_policies[policy_id](obs)[0].log_prob(actions)
        new_ll = self.policies[policy_id](obs)[0].log_prob(actions)

        # pi_new / pi_old, computed in log space for numerical stability.
        likelihood_ratio = (new_ll - old_ll).exp()

        # Unclipped surrogate.
        surrogate = likelihood_ratio * advantages

        # Clip the ratio to [1 - eps, 1 + eps].
        likelihood_ratio_clip = torch.clamp(
            likelihood_ratio, min=1 - self._lr_clip_range, max=1 + self._lr_clip_range
        )

        # Clipped surrogate.
        surrogate_clip = likelihood_ratio_clip * advantages

        # Pessimistic (lower) bound of the two surrogates.
        return torch.min(surrogate, surrogate_clip)
