"""
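Residual PPO on top of a frozen, pretrained diffusion policy: a Gaussian residual network is fine-tuned with PPO, and its output (scaled by `residual_scale`) is added to the base diffusion action.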

K: number of denoising steps
To: observation sequence length
Ta: action chunk size
Do: observation dimension
Da: action dimension

C: image channels
H, W: image height and width

"""

import torch.distributions as D
import torch
import logging
from typing import Optional

from model.diffusion.diffusion import DiffusionModel

log = logging.getLogger(__name__)


class ResidualDiffusion(DiffusionModel):
    """Frozen diffusion base policy plus a Gaussian residual policy trained with PPO.

    The pretrained diffusion network (``actor``) is frozen. A residual network
    (``actor_ft``) maps ``[latest state, base action]`` to the mean and scale of a
    diagonal Gaussian, and the executed action is
    ``base_action + residual_scale * residual_action``. A critic evaluated on the
    same residual observation provides the PPO value baseline.
    """

    def __init__(
        self,
        actor,
        critic,
        residual,
        horizon_steps,
        clip_ploss_coef: float,
        clip_vloss_coef: Optional[float] = None,
        norm_adv: Optional[bool] = True,
        residual_scale=0.1,
        network_path=None,
        randn_clip_value=10,
        tanh_output=False,
        **kwargs,
    ):
        """
        Args:
            actor: pretrained diffusion network; passed to ``DiffusionModel`` and frozen.
            critic: value network evaluated on the residual observation.
            residual: network returning ``(means, scales)`` of the residual Gaussian.
            horizon_steps: action horizon forwarded to ``DiffusionModel``.
            clip_ploss_coef: PPO ratio clipping range for the policy loss.
            clip_vloss_coef: clipping range for the value loss; ``None`` disables clipping.
            norm_adv: whether to normalize advantages within each batch.
            residual_scale: multiplier applied to residual actions in ``forward``.
            network_path: optional checkpoint path. It is forwarded to ``DiffusionModel``;
                if the checkpoint has no ``"ema"`` entry, its ``"model"`` state dict is
                also loaded (non-strict) into this module as a trained RL model.
            randn_clip_value: sampled residual actions are clamped to
                ``mean +/- randn_clip_value * scale``.
            tanh_output: whether to squash sampled residual actions with tanh.
            **kwargs: forwarded to ``DiffusionModel``.
        """
        super().__init__(
            network=actor,
            network_path=network_path,
            horizon_steps=horizon_steps,
            **kwargs,
        )

        self.actor = self.network
        self.actor_ft = residual.to(self.device)
        self.critic = critic.to(self.device)

        self.horizon_steps = horizon_steps
        # Clamp sampled residual actions to within `randn_clip_value` scales of the mean
        self.randn_clip_value = randn_clip_value
        # Whether to apply tanh to the **sampled** residual action --- used in SAC
        self.tanh_output = tanh_output
        # Whether to normalize advantages within batch
        self.norm_adv = norm_adv
        # Clipping value for policy loss
        self.clip_ploss_coef = clip_ploss_coef
        # Clipping value for value loss
        self.clip_vloss_coef = clip_vloss_coef
        # Scale applied to residual actions before adding them to base actions
        self.residual_scale = residual_scale

        # Freeze the pretrained diffusion network; only the residual (and critic) train
        for param in self.actor.parameters():
            param.requires_grad = False
        log.info("Turned off gradients of the pretrained network")
        n_trainable = sum(
            p.numel() for p in self.actor_ft.parameters() if p.requires_grad
        )
        log.info("Number of finetuned parameters: %d", n_trainable)

        if network_path is not None:
            checkpoint = torch.load(
                network_path, map_location=self.device, weights_only=True
            )
            if "ema" not in checkpoint:  # load trained RL model
                self.load_state_dict(checkpoint["model"], strict=False)
                log.info("Loaded critic from %s", network_path)

    @torch.no_grad()
    def forward(
        self,
        cond,
        deterministic=False,
    ):
        """
        Forward pass for sampling actions. Used in evaluating pre-trained/fine-tuned policy. Not modifying diffusion clipping

        The base policy is queried once, and the residual is conditioned on that same
        base sample before being added to it.

        Args:
            cond: dict with key state/rgb; more recent obs at the end
                state: (B, To, Do)
                rgb: (B, To, C, H, W)
            deterministic: if True, the residual Gaussian uses a tiny scale (1e-4).

        Returns:
            (B, 1, Da): first base action plus ``residual_scale`` times the residual action.
        """
        base_trajectories = self.forward_base(cond=cond).trajectories
        residual_obs = self._build_residual_obs(cond, base_trajectories)
        residual_actions = self._sample_residual(
            residual_obs,
            deterministic=deterministic,
        )
        return base_trajectories[:, :1, :] + self.residual_scale * residual_actions

    # ---------- Base Policy ----------#

    @torch.no_grad()
    def forward_base(self, cond):
        """Sample from the frozen diffusion base policy with ``deterministic=True``."""
        return super().forward(cond=cond, deterministic=True)

    def il_loss(self, x, *args):
        """Imitation loss of the underlying ``DiffusionModel``."""
        return super().loss(x, *args)

    # ---------- RL training ----------#

    def _build_residual_obs(self, cond, base_trajectories):
        """Concatenate the latest state with a base action to form the residual input.

        ``base_trajectories`` is indexed as (B, steps, Da); presumably
        (B, horizon_steps, Da) from the base policy. TODO confirm.
        """
        latest_state = cond["state"][:, -1, :]
        # NOTE(review): the residual/critic are conditioned on the LAST base action,
        # while `forward` adds the residual to the FIRST base action. These coincide
        # only when horizon_steps == 1. Confirm intent before changing, because
        # altering the index changes the input of already-trained networks.
        last_base_action = base_trajectories[:, -1, :]
        return torch.cat([latest_state, last_base_action], dim=-1).to(self.device)

    def get_residual_obs(self, cond):
        """Run the base policy and return ``[latest state, base action]`` on ``self.device``."""
        base_trajectories = self.forward_base(cond=cond).trajectories
        return self._build_residual_obs(cond, base_trajectories)

    def get_value(self, cond):
        """Critic value of the residual observation built from ``cond``."""
        return self.critic(self.get_residual_obs(cond))

    def rl_loss(
        self,
        obs,
        actions,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
    ):
        """
        PPO loss for the residual policy and the critic.

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            rgb: (B, To, C, H, W)
        actions: (B, Ta, Da) residual actions taken during rollout
        returns: (B, )
        oldvalues: (B, )
        advantages: (B, )
        oldlogprobs: (B, )
        use_bc_loss: if True, also compute a behavior-cloning loss toward the base policy

        Returns:
            tuple of (pg_loss, entropy_loss, v_loss, clipfrac, approx_kl,
            mean ratio, bc_loss, mean residual std). clipfrac, approx_kl, mean ratio
            and std are Python floats; bc_loss is 0.0 when ``use_bc_loss`` is False.
        """
        # Build the residual observation once. It requires a base-policy call, so
        # sharing it avoids redundant sampling and keeps actor, BC and critic terms
        # on the same conditioning.
        residual_obs = self.get_residual_obs(obs)

        newlogprobs, entropy, std = self._logprobs_from_residual_obs(
            residual_obs, actions
        )
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)
        entropy_loss = -entropy

        bc_loss = 0.0
        if use_bc_loss:
            # See Eqn. 2 of https://arxiv.org/pdf/2403.03949.pdf
            # Give a reward for maximizing probability of teacher policy's action with current policy.
            # Actions are chosen along trajectory induced by current policy.
            #
            # The teacher is the frozen base policy. The executed action is
            # base + residual_scale * residual, so a zero residual reproduces the
            # teacher's action exactly; BC maximizes the likelihood of zero residuals.
            teacher_residuals = torch.zeros_like(actions)
            bc_logprobs, _, _ = self._logprobs_from_residual_obs(
                residual_obs, teacher_residuals
            )
            bc_logprobs = bc_logprobs.clamp(min=-5, max=2)
            bc_loss = -bc_logprobs.mean()

        # get ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # get kl difference and whether value clipped
        with torch.no_grad():
            approx_kl = ((ratio - 1) - logratio).nanmean()
            clipfrac = (
                ((ratio - 1.0).abs() > self.clip_ploss_coef).float().mean().item()
            )

        # normalize advantages
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy loss with clipping
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - self.clip_ploss_coef, 1 + self.clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss optionally with clipping
        newvalues = self.critic(residual_obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()
        return (
            pg_loss,
            entropy_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
            std.item(),
        )

    def forward_rl_train(
        self,
        cond,
        deterministic=False,
        network_override=None,
    ):
        """
        Build the diagonal Gaussian over residual actions.

        Calls the residual network (or ``network_override``) on ``cond`` to get
        per-dimension means and scales and returns a ``torch.distributions.Normal``.
        When ``deterministic`` is True, the scale is replaced by 1e-4 everywhere.
        """
        if network_override is not None:
            means, scales = network_override(cond)
        else:
            means, scales = self.actor_ft(cond)
        if deterministic:
            # low-noise for all Gaussian dists
            scales = torch.ones_like(means) * 1e-4
        return D.Normal(loc=means, scale=scales)

    def _sample_residual(
        self,
        residual_obs,
        deterministic=False,
        network_override=None,
        reparameterize=False,
        get_logprob=False,
    ):
        """Sample a one-step residual action from a prebuilt residual observation.

        Returns ``(B, 1, Da)`` actions, plus per-sample log-probs summed over action
        dims when ``get_logprob`` is True.
        """
        batch_size = len(residual_obs)
        n_steps = 1  # 1-step correction

        dist = self.forward_rl_train(
            residual_obs,
            deterministic=deterministic,
            network_override=network_override,
        )
        if reparameterize:
            sampled_action = dist.rsample()
        else:
            sampled_action = dist.sample()
        # Keep samples within randn_clip_value scales of the mean
        sampled_action.clamp_(
            dist.loc - self.randn_clip_value * dist.scale,
            dist.loc + self.randn_clip_value * dist.scale,
        )

        if get_logprob:
            log_prob = dist.log_prob(sampled_action)

            # For SAC/RLPD, squash mean after sampling here instead of right after model output as in PPO
            if self.tanh_output:
                sampled_action = torch.tanh(sampled_action)
                # Change-of-variables correction for the tanh squashing
                log_prob -= torch.log(1 - sampled_action.pow(2) + 1e-6)
            return sampled_action.view(batch_size, n_steps, -1), log_prob.sum(
                1, keepdim=False
            )
        if self.tanh_output:
            sampled_action = torch.tanh(sampled_action)
        return sampled_action.view(batch_size, n_steps, -1)

    def forward_rl(
        self,
        cond,
        deterministic=False,
        network_override=None,
        reparameterize=False,
        get_logprob=False,
    ):
        """Sample residual actions for ``cond``.

        Runs the base policy to build the residual observation, then samples from the
        residual Gaussian.

        Returns:
            ``(B, 1, Da)`` residual actions; with ``get_logprob=True``, a tuple of the
            actions and their log-probs summed over action dims.
        """
        residual_obs = self.get_residual_obs(cond=cond)
        return self._sample_residual(
            residual_obs,
            deterministic=deterministic,
            network_override=network_override,
            reparameterize=reparameterize,
            get_logprob=get_logprob,
        )

    def _logprobs_from_residual_obs(self, residual_obs, actions):
        """Log-prob, mean entropy and mean scale of ``actions`` under the residual Gaussian."""
        batch_size = len(actions)
        dist = self.forward_rl_train(
            residual_obs,
            deterministic=False,
        )
        # The Gaussian is diagonal (independent per-dimension Normals), so the joint
        # log-density is the sum of the per-dimension log-densities.
        # NOTE(review): no tanh correction is applied here, unlike `forward_rl` with
        # get_logprob=True; this is only consistent when tanh_output is False.
        log_prob = dist.log_prob(actions.view(batch_size, -1)).sum(-1)
        # Mean over batch AND action dims (per-dimension entropy, not the joint sum)
        entropy = dist.entropy().mean()
        std = dist.scale.mean()
        return log_prob, entropy, std

    def get_logprobs(
        self,
        cond,
        actions,
    ):
        """Return ``(log_prob (B,), mean entropy, mean std)`` of residual ``actions`` given ``cond``."""
        residual_obs = self.get_residual_obs(cond=cond)
        return self._logprobs_from_residual_obs(residual_obs, actions)
