"""
Diffusion policy gradient (PPO) with exact likelihood estimation.

Based on score_sde_pytorch: https://github.com/yang-song/score_sde_pytorch

Notation:
    To: observation sequence length
    Ta: action chunk size
    Do: observation dimension
    Da: action dimension

"""

import torch
import logging

log = logging.getLogger(__name__)
from .diffusion_ppo import PPODiffusion
from .exact_likelihood import get_likelihood_fn


class PPOExactDiffusion(PPODiffusion):
    """PPO fine-tuning of a diffusion policy with SDE-based action log-likelihoods.

    Log-probabilities of sampled action chunks are obtained from a likelihood
    function built by ``get_likelihood_fn`` around an SDE object. They are then
    used in a clipped PPO policy loss, alongside an (optionally clipped) value
    loss on ``self.critic``.
    """

    def __init__(
        self,
        sde,
        sde_hutchinson_type="Rademacher",
        sde_rtol=1e-4,
        sde_atol=1e-4,
        sde_eps=1e-4,
        sde_step_size=1e-3,
        sde_method="RK23",
        sde_continuous=False,
        sde_probability_flow=False,
        sde_num_epsilon=1,
        sde_min_beta=1e-2,
        **kwargs,
    ):
        """
        Args:
            sde: SDE object. It must provide ``set_betas(betas, min_beta)`` and
                is also passed to ``get_likelihood_fn``.
            sde_hutchinson_type, sde_rtol, sde_atol, sde_eps, sde_step_size,
            sde_method, sde_continuous, sde_probability_flow, sde_num_epsilon:
                Forwarded unchanged to ``get_likelihood_fn`` as
                ``hutchinson_type``, ``rtol``, ``atol``, ``eps``, ``step_size``,
                ``method``, ``continuous``, ``probability_flow`` and
                ``num_epsilon`` respectively.
            sde_min_beta: Second argument passed to ``sde.set_betas``.
            **kwargs: Forwarded to ``PPODiffusion.__init__``.
        """
        super().__init__(**kwargs)
        self.sde = sde
        # NOTE(review): self.betas / self.predict_epsilon are presumably set by
        # PPODiffusion.__init__ (they are not defined here) - confirm.
        self.sde.set_betas(
            self.betas,
            sde_min_beta,
        )

        # set up likelihood function
        self.likelihood_fn = get_likelihood_fn(
            sde,
            hutchinson_type=sde_hutchinson_type,
            rtol=sde_rtol,
            atol=sde_atol,
            eps=sde_eps,
            step_size=sde_step_size,
            method=sde_method,
            continuous=sde_continuous,
            probability_flow=sde_probability_flow,
            predict_epsilon=self.predict_epsilon,
            num_epsilon=sde_num_epsilon,
        )

    def get_exact_logprobs(self, cond, samples):
        """Compute log-probabilities of ``samples`` under the current policy.

        Delegates to ``self.likelihood_fn``, passing both ``self.actor`` and
        ``self.actor_ft`` together with ``self.denoising_steps`` and
        ``self.ft_denoising_steps``.

        NOTE(review): the original docstring said this uses torchdiffeq; the
        solver is chosen inside ``get_likelihood_fn`` - confirm.

        Args:
            cond: Conditioning observations, passed as ``cond=``.
            samples: Action chunks, (B x Ta x Da).

        Returns:
            The result of ``likelihood_fn``. ``loss`` treats it as a (B,) tensor.
        """
        # TODO: image input
        return self.likelihood_fn(
            self.actor,
            self.actor_ft,
            samples,
            self.denoising_steps,
            self.ft_denoising_steps,
            cond=cond,
        )

    def loss(
        self,
        obs,
        samples,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
        **kwargs,
    ):
        """
        PPO loss

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
        samples: (B, Ta, Da)
        returns: (B, )
        oldvalues: (B, )
        advantages: (B,)
        oldlogprobs: (B, )
        use_bc_loss: not supported; raises NotImplementedError when True
        **kwargs: accepted for interface compatibility and ignored

        Returns:
            Tuple ``(pg_loss, v_loss, clipfrac, approx_kl, ratio_mean, bc_loss)``:
            ``pg_loss`` and ``v_loss`` are scalar tensors; ``clipfrac``,
            ``approx_kl`` and ``ratio_mean`` are Python floats; ``bc_loss`` is
            always 0.
        """
        # Fail fast, before the expensive likelihood computation below.
        if use_bc_loss:
            raise NotImplementedError
        bc_loss = 0

        # Get new logprobs for final x. Both new and old logprobs are clamped
        # to [-5, 2]; note that clamping zeroes the gradient outside that range.
        newlogprobs = self.get_exact_logprobs(obs, samples)
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)

        # get ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # Diagnostics only (no gradient).
        with torch.no_grad():
            # approx_kl: k3 estimator (ratio - 1) - logratio, a lower-variance
            # alternative to the k1 estimator (-logratio).mean(); see
            # http://joschu.net/blog/kl-approx.html
            approx_kl = ((ratio - 1) - logratio).mean()
            # Fraction of samples whose ratio lies outside the clip range.
            clipfrac = (
                ((ratio - 1.0).abs() > self.clip_ploss_coef).float().mean().item()
            )

        # normalize advantages
        # NOTE(review): with a single-element batch, advantages.std() (unbiased)
        # is NaN, which would make the loss NaN.
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy loss with clipping
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - self.clip_ploss_coef, 1 + self.clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss optionally with clipping
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()

        # No entropy bonus is computed in this loss.
        return (
            pg_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
        )
