import torch
import torch.nn as nn
import torch.optim as optim


class PPO:
    """Proximal Policy Optimization (clipped surrogate) with optional mirror symmetry.

    ``actor_critic`` is expected to provide ``param_groups()``, ``parameters()``,
    ``evaluate_actions(obs, states, masks, actions)`` returning
    ``(values, action_log_probs, dist_entropy, states)``, and
    ``act(obs, states, masks, deterministic=...)`` whose element ``[1]`` is the
    action.

    ``mirror_function`` (optional) maps an 8-element mini-batch sample
    ``(obs, states, actions, value_preds, returns, masks, old_log_probs, adv)``
    to its mirrored counterpart. How it is used depends on ``symmetry_coef``:

    * ``symmetry_coef == 0``: each mini-batch is replaced by
      ``mirror_function(sample)`` before the PPO losses are computed.
    * ``symmetry_coef != 0``: the PPO losses use the unmodified mini-batch and
      ``symmetry_coef * mean((M(pi(s)) - pi(M(s))) ** 2)`` is added, where
      ``pi`` is the deterministic action and ``M`` is ``mirror_function``.

    Args:
        clip_param: Clipping range for the probability ratio and, when
            ``use_clipped_value_loss`` is set, for the value prediction.
        ppo_epoch: Number of passes over the rollout per ``update`` call.
        num_mini_batch: Number of mini-batches requested from the rollout
            generator per pass.
        value_loss_coef, entropy_coef, symmetry_coef: Loss weights.
        lr, eps: Adam learning rate / epsilon; ``None`` keeps Adam's default.
        max_grad_norm: Gradient-norm clip threshold; ``None`` disables clipping.
        use_clipped_value_loss: Use the PPO-clipped value loss instead of plain MSE.
        mirror_function: Optional sample mirroring function (see above).
    """

    def __init__(
        self,
        actor_critic,
        clip_param,
        ppo_epoch,
        num_mini_batch,
        value_loss_coef,
        entropy_coef,
        symmetry_coef=0,
        lr=None,
        eps=None,
        max_grad_norm=None,
        use_clipped_value_loss=True,
        mirror_function=None,
    ):
        self.actor_critic = actor_critic

        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch
        self.num_mini_batch = num_mini_batch
        self.symmetry_coef = symmetry_coef

        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef

        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

        self.mirror_function = mirror_function

        # torch.optim.Adam rejects None for lr/eps, so only forward the values
        # that were supplied and otherwise fall back to Adam's own defaults.
        adam_kwargs = {}
        if lr is not None:
            adam_kwargs["lr"] = lr
        if eps is not None:
            adam_kwargs["eps"] = eps
        self.optimizer = optim.Adam(actor_critic.param_groups(), **adam_kwargs)

    def update(self, rollouts):
        """Run ``ppo_epoch`` passes of mini-batch PPO updates over ``rollouts``.

        ``rollouts`` must expose ``returns`` and ``value_preds`` tensors (one
        more step than the number of transitions) and
        ``feed_forward_generator(advantages, num_mini_batch)`` yielding 8-tuples
        in the order documented on the class.

        Returns:
            ``(value_loss, action_loss, dist_entropy, symmetry_loss)`` as floats,
            each averaged over the optimizer steps actually taken (all zeros if
            no step was taken).
        """
        advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
        # Normalize advantages; the epsilon guards against a zero std.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        value_loss_epoch = 0.0
        action_loss_epoch = 0.0
        dist_entropy_epoch = 0.0
        symmetry_loss_epoch = 0.0
        num_updates = 0

        has_mirror = self.mirror_function is not None
        use_mirrored_data = has_mirror and self.symmetry_coef == 0
        use_symmetry_loss = has_mirror and self.symmetry_coef != 0

        for _ in range(self.ppo_epoch):
            data_generator = rollouts.feed_forward_generator(
                advantages, self.num_mini_batch
            )

            for sample in data_generator:
                batch = self.mirror_function(sample) if use_mirrored_data else sample
                (
                    observations_batch,
                    states_batch,
                    actions_batch,
                    value_preds_batch,
                    return_batch,
                    masks_batch,
                    old_action_log_probs_batch,
                    adv_targ,
                ) = batch

                values, action_log_probs, dist_entropy, _ = (
                    self.actor_critic.evaluate_actions(
                        observations_batch, states_batch, masks_batch, actions_batch
                    )
                )

                action_loss = self._action_loss(
                    action_log_probs, old_action_log_probs_batch, adv_targ
                )
                value_loss = self._value_loss(values, value_preds_batch, return_batch)

                loss = (
                    value_loss * self.value_loss_coef
                    + action_loss
                    - dist_entropy * self.entropy_coef
                )
                if use_symmetry_loss:
                    symmetry_loss = self._symmetry_loss(
                        sample, observations_batch, states_batch, masks_batch
                    )
                    loss = loss + symmetry_loss * self.symmetry_coef
                    symmetry_loss_epoch += symmetry_loss.item()

                self.optimizer.zero_grad()
                loss.backward()
                # max_grad_norm=None means "no clipping"; clip_grad_norm_ itself
                # cannot accept None.
                if self.max_grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        self.actor_critic.parameters(), self.max_grad_norm
                    )
                self.optimizer.step()

                value_loss_epoch += value_loss.item()
                action_loss_epoch += action_loss.item()
                dist_entropy_epoch += dist_entropy.item()
                num_updates += 1

        # Average over the steps really performed; avoid dividing by zero when
        # ppo_epoch == 0 or the generator yields nothing.
        num_updates = max(num_updates, 1)

        return (
            value_loss_epoch / num_updates,
            action_loss_epoch / num_updates,
            dist_entropy_epoch / num_updates,
            symmetry_loss_epoch / num_updates,
        )

    def _action_loss(self, action_log_probs, old_action_log_probs, adv_targ):
        """Return the negated PPO clipped surrogate objective (to be minimized)."""
        ratio = torch.exp(action_log_probs - old_action_log_probs)
        surr1 = ratio * adv_targ
        surr2 = (
            torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
            * adv_targ
        )
        return -torch.min(surr1, surr2).mean()

    def _value_loss(self, values, value_preds, returns):
        """Return half the mean squared value error, optionally PPO-clipped."""
        if not self.use_clipped_value_loss:
            return 0.5 * (returns - values).pow(2).mean()
        # Limit how far the new prediction may move from the stored prediction
        # and take the larger (pessimistic) of the clipped/unclipped errors.
        value_pred_clipped = value_preds + (values - value_preds).clamp(
            -self.clip_param, self.clip_param
        )
        value_losses = (values - returns).pow(2)
        value_losses_clipped = (value_pred_clipped - returns).pow(2)
        return 0.5 * torch.max(value_losses, value_losses_clipped).mean()

    def _symmetry_loss(self, sample, observations_batch, states_batch, masks_batch):
        """Return mean((M(pi(s)) - pi(M(s))) ** 2) for one mini-batch."""
        pi_mean = self.actor_critic.act(
            observations_batch, states_batch, masks_batch, deterministic=True
        )[1]
        # Swap the stored (sampled) actions for the deterministic ones so that
        # mirror_function mirrors pi(s); the caller's sample is left untouched.
        sample_with_mean = list(sample)
        sample_with_mean[2] = pi_mean
        mirror_obs_batch, _, pi_mean_mirror, *_ = self.mirror_function(
            sample_with_mean
        )
        pi_mirror_mean = self.actor_critic.act(
            mirror_obs_batch, states_batch, masks_batch, deterministic=True
        )[1]
        return (pi_mean_mirror - pi_mirror_mean).pow(2).mean()
