"""
Gaussian diffusion with DDPM and optionally DDIM sampling.

References:
Diffuser: https://github.com/jannerm/diffuser
Diffusion Policy: https://github.com/columbia-ai-robotics/diffusion_policy/blob/main/diffusion_policy/policy/diffusion_unet_lowdim_policy.py
Annotated DDIM/DDPM: https://nn.labml.ai/diffusion/stable_diffusion/sampler/ddpm.html

"""

import logging
from debug import debug_print
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Normal, Independent

log = logging.getLogger(__name__)

from onpolicy.algorithms.diffusion_ac.dppo.sampling import (
    extract,
    cosine_beta_schedule,
    make_timesteps,
)

from collections import namedtuple
# Result container with the fields `trajectories` and `chains`. Nothing in the
# visible part of this module constructs it; the sampling methods below return
# plain tensors / tuples instead.
Sample = namedtuple("Sample", "trajectories chains")


class DiffusionModel(nn.Module):

    def __init__(
        self,
        obs_dim,
        action_dim,
        network,
        denoising_steps=20,
        predict_epsilon=True,
        network_path=None,
        device="cuda",
        horizon_steps=1,
        # Various clipping
        denoised_clip_value=1.0,
        randn_clip_value=10,
        final_action_clip_value=None,
        eps_clip_value=None,  # DDIM only
        # DDPM parameters
        # DDIM sampling
        use_ddim=False,
        ddim_discretize='uniform',
        ddim_steps=None,
        args=None,
        **kwargs,
    ):
        """Wrap the denoising network and precompute the noise schedule.

        Args:
            obs_dim: Observation dimensionality (stored; not used in this
                class's visible code).
            action_dim: Size of the last axis of a sampled action tensor.
            network: Denoiser module; moved to ``device`` and called as
                ``network(x, t, cond=cond)``.
            denoising_steps: Number of DDPM steps (cast to ``int``).
            predict_epsilon: If True the network output is treated as the
                noise ε, otherwise as x₀ directly.
            network_path: Optional checkpoint path. The ``"ema"`` entry is
                loaded when present, otherwise ``"model"``; both are loaded
                with ``strict=False``.
            device: Device for the network and every schedule tensor.
            horizon_steps: Number of action steps per sample.
            denoised_clip_value: Bound applied to the reconstructed x₀ in
                ``p_mean_var``; ``None`` disables it.
            randn_clip_value: Bound applied to the standard-normal noise drawn
                in ``p_sample_loop``.
            final_action_clip_value: Stored only; the clamp that would use it
                is commented out in the sampling loops.
            eps_clip_value: Bound applied to ε in ``p_mean_var`` (DDIM only).
            use_ddim: Build the DDIM tables and sample with them.
            ddim_discretize: Only ``'uniform'`` is supported.
            ddim_steps: Number of DDIM steps; required when ``use_ddim``.
            args: Config object; the sampling loops read its
                ``use_latent_prob`` attribute.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: ``use_ddim`` combined with ``predict_epsilon=False``.
            ValueError: Unknown ``ddim_discretize``, or ``ddim_steps`` missing
                or non-positive while ``use_ddim`` is set.
        """
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.denoising_steps = int(denoising_steps)
        self.predict_epsilon = predict_epsilon
        self.use_ddim = use_ddim
        self.ddim_steps = ddim_steps

        # Clip the reconstructed x₀ at each denoising step
        self.denoised_clip_value = denoised_clip_value

        # Bound for the final sampled action (currently not applied anywhere)
        self.final_action_clip_value = final_action_clip_value

        # For each denoising step, clip the sampled randn so that the sampled
        # action is not too far away from the mean
        self.randn_clip_value = randn_clip_value

        # Clip epsilon for numerical stability
        self.eps_clip_value = eps_clip_value

        # Set up models
        self.network = network.to(device)
        if network_path is not None:
            checkpoint = torch.load(network_path, map_location=device, weights_only=True)
            if "ema" in checkpoint:
                self.load_state_dict(checkpoint["ema"], strict=False)
                print("Path", network_path)
                logging.info("Loaded SL-trained policy from %s", network_path)
            else:
                self.load_state_dict(checkpoint["model"], strict=False)
                logging.info("Loaded RL-trained policy from %s", network_path)
        logging.info(
            "Number of network parameters: %d",
            sum(p.numel() for p in self.parameters()),
        )

        # ---------------- DDPM parameters ----------------
        # βₜ  (use the int-cast step count so a float argument cannot leak in)
        self.betas = cosine_beta_schedule(self.denoising_steps).to(device)
        # αₜ = 1 - βₜ
        self.alphas = 1.0 - self.betas
        # α̅ₜ = ∏ᵗₛ₌₁ αₛ
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # α̅ₜ₋₁ (defined as 1 for the first step)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]])
        # √α̅ₜ
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        # √(1 - α̅ₜ)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # √(1 / α̅ₜ)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        # √(1 / α̅ₜ - 1)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        # β̃ₜ = σₜ² = βₜ (1 - α̅ₜ₋₁) / (1 - α̅ₜ)
        self.ddpm_var = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # The log is clamped because β̃ is exactly 0 at the first step (α̅ₜ₋₁ = 1)
        self.ddpm_logvar_clipped = torch.log(torch.clamp(self.ddpm_var, min=1e-20))
        # μₜ = βₜ √α̅ₜ₋₁ / (1 - α̅ₜ) x₀ + √αₜ (1 - α̅ₜ₋₁) / (1 - α̅ₜ) xₜ
        self.ddpm_mu_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.ddpm_mu_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)

        # ---------------- DDIM parameters ----------------
        # In the DDIM paper https://arxiv.org/pdf/2010.02502, alpha is
        # alpha_cumprod in DDPM https://arxiv.org/pdf/2102.09672
        if use_ddim:
            assert predict_epsilon, "DDIM requires predicting epsilon for now."
            if ddim_discretize != 'uniform':
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError('Unknown discretization method for DDIM.')
            if ddim_steps is None or ddim_steps <= 0:
                raise ValueError('ddim_steps must be a positive integer when use_ddim=True.')

            # Uniform spacing, the HF "leading" style
            step_ratio = self.denoising_steps // ddim_steps
            self.ddim_t = torch.arange(0, ddim_steps, device=self.device) * step_ratio

            self.ddim_alphas = self.alphas_cumprod[self.ddim_t].clone().to(torch.float32)
            self.ddim_alphas_sqrt = torch.sqrt(self.ddim_alphas)
            self.ddim_alphas_prev = torch.cat([
                torch.tensor([1.]).to(torch.float32).to(self.device),
                self.alphas_cumprod[self.ddim_t[:-1]]])
            self.ddim_sqrt_one_minus_alphas = (1. - self.ddim_alphas) ** .5

            # Initialize fixed sigmas for inference - eta=0
            ddim_eta = 0
            self.ddim_sigmas = (ddim_eta * \
                    ((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * \
                    (1 - self.ddim_alphas / self.ddim_alphas_prev)) ** .5)

            # Flip all tables so index 0 is the first (noisiest) sampling step
            self.ddim_t = torch.flip(self.ddim_t, [0])
            self.ddim_alphas = torch.flip(self.ddim_alphas, [0])
            self.ddim_alphas_sqrt = torch.flip(self.ddim_alphas_sqrt, [0])
            self.ddim_alphas_prev = torch.flip(self.ddim_alphas_prev, [0])
            self.ddim_sqrt_one_minus_alphas = torch.flip(self.ddim_sqrt_one_minus_alphas, [0])
            self.ddim_sigmas = torch.flip(self.ddim_sigmas, [0])

        # ---------------- State used by update_eta ----------------
        self.eta = torch.zeros(1, device=device)
        self.base_betas = torch.ones_like(self.betas, device=device) * 0.7
        self.target_betas = self.betas.clone().to(device)
        self.log_one_minus_alphas_cumprod = torch.log(1. - self.alphas_cumprod)
        self.args = args
        # NOTE: every schedule tensor above is a plain attribute, not a
        # registered buffer, so none of them appear in state_dict() or follow
        # Module.to(); they stay on the `device` passed to this constructor.
    
    def update_eta(self, v):
        if v < -1e-6:
            return
        self.eta.fill_(v)

        betas = self.base_betas * (self.target_betas / self.base_betas).pow(self.eta)
        debug_print(betas)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1).to(betas), alphas_cumprod[:-1]])

        self.betas.copy_(betas)
        self.alphas_cumprod.copy_(alphas_cumprod)
        self.alphas_cumprod_prev.copy_(alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod.copy_(torch.sqrt(alphas_cumprod))
        self.sqrt_one_minus_alphas_cumprod.copy_(torch.sqrt(1. - alphas_cumprod))
        self.log_one_minus_alphas_cumprod.copy_(torch.log(1. - alphas_cumprod))
        self.sqrt_recip_alphas_cumprod.copy_(torch.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod.copy_(torch.sqrt(1. / alphas_cumprod - 1))
        
        self.ddpm_var = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.ddpm_logvar_clipped = torch.log(torch.clamp(self.ddpm_var, min=0.1))
        """
        μₜ = β̃ₜ √ α̅ₜ₋₁/(1-α̅ₜ)x₀ + √ αₜ (1-α̅ₜ₋₁)/(1-α̅ₜ)xₜ
        """
        self.ddpm_mu_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.ddpm_mu_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # self.posterior_variance.copy_(posterior_variance)

        # ## log calculation clipped because the posterior variance
        # ## is 0 at the beginning of the diffusion chain
        # self.posterior_log_variance_clipped.copy_(torch.log(torch.clamp(posterior_variance, min=0.1)))
        # self.posterior_mean_coef1.copy_(betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))

        # self.posterior_mean_coef2.copy_((1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    # ---------- Sampling ----------#

    def p_mean_var(self, x, t, cond, index=None, network_override=None):
        """Return the mean and log-variance of one reverse (denoising) step.

        Args:
            x: Current noisy sample xₜ.
            t: Batch of timesteps used to index the DDPM tables and passed to
                the network.
            cond: Conditioning input forwarded to the network as ``cond=``.
            index: Batch of positions into the DDIM tables; required when
                ``self.use_ddim`` is set, unused otherwise.
            network_override: Optional callable used instead of
                ``self.network`` with the same ``(x, t, cond=cond)`` call.

        Returns:
            ``(mu, logvar)`` for the Gaussian over the next sample. Under DDIM
            ``logvar`` is ``log(sigma ** 2)``, which is ``-inf`` when sigma is
            zero (as with the fixed eta = 0 set up in ``__init__``).

        Raises:
            ValueError: ``self.use_ddim`` is set but ``index`` is None.
        """
        if self.use_ddim and index is None:
            raise ValueError("p_mean_var needs `index` when use_ddim is enabled.")

        denoiser = self.network if network_override is None else network_override
        noise = denoiser(x, t, cond=cond)

        # Predict x_0
        if self.predict_epsilon:
            if self.use_ddim:
                # x₀ = (xₜ - √(1-αₜ) ε) / √αₜ
                alpha = extract(self.ddim_alphas, index, x.shape)
                alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
                sqrt_one_minus_alpha = extract(self.ddim_sqrt_one_minus_alphas, index, x.shape)
                x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha ** 0.5)
            else:
                # x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
                x_recon = (
                    extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                    - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
                )
        else:   # directly predicting x₀
            x_recon = noise

        if self.denoised_clip_value is not None:
            # Out-of-place clamp: when the network predicts x₀ directly,
            # x_recon *is* the network output, and mutating it in place can
            # invalidate tensors autograd has saved for the backward pass.
            x_recon = x_recon.clamp(-self.denoised_clip_value, self.denoised_clip_value)
            if self.use_ddim:
                # re-calculate noise based on clamped x_recon - default to false in HF, but let's use it here
                noise = (x - alpha ** (0.5) * x_recon) / sqrt_one_minus_alpha

        # Clip epsilon for numerical stability in policy gradient - the value
        # can be huge sometimes. This has no effect if DDPM is used.
        if self.use_ddim and self.eps_clip_value is not None:
            noise = noise.clamp(-self.eps_clip_value, self.eps_clip_value)

        # Get mu
        if self.use_ddim:
            # μ = √αₜ₋₁ x₀ + √(1 - αₜ₋₁ - σₜ²) ε
            sigma = extract(self.ddim_sigmas, index, x.shape)
            dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * noise
            mu = (alpha_prev ** 0.5) * x_recon + dir_xt
            var = sigma ** 2
            logvar = torch.log(var)
        else:
            # μₜ = βₜ √α̅ₜ₋₁ / (1 - α̅ₜ) x₀ + √αₜ (1 - α̅ₜ₋₁) / (1 - α̅ₜ) xₜ
            mu = (
                extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
                + extract(self.ddpm_mu_coef2, t, x.shape) * x
            )
            logvar = extract(
                self.ddpm_logvar_clipped, t, x.shape
            )
        return mu, logvar

    # @torch.no_grad()
    def p_sample_loop(self, cond, return_diffusion=False, return_noise=False, deterministic=False):
        """Sample actions by running the full reverse diffusion chain.

        With K denoising steps and D = horizon_steps * action_dim:

        Args:
            cond: Conditioning input forwarded to the network; ``len(cond)``
                is used as the batch size B. NOTE(review): an earlier
                docstring described a dict with state/rgb keys, but ``len``
                of a dict counts keys - confirm callers pass a batch-first
                tensor.
            return_diffusion: Also return the chain of intermediate samples
                and per-step log-probabilities.
            return_noise: Additionally return the per-step noise record. The
                chain and log-probabilities are returned as well.
            deterministic: Currently unused.

        Returns:
            * neither flag set: final sample, shape (B, D).
            * ``return_diffusion`` only: ``(sample, chain, log_probs)`` with
              ``chain`` of shape (B, K + 1, D) and ``log_probs`` of shape
              (B, K, 1). ``log_probs`` is all zeros unless
              ``args.use_latent_prob`` is truthy.
            * ``return_noise``: ``(sample, chain, log_probs, noises)`` with
              ``noises`` of shape (B, K + 1, D). Entry 0 is the initial
              sample; later entries hold the drawn noise, or the sampled
              latent itself when ``args.use_latent_prob`` is truthy.

        ``final_action_clip_value`` is not applied here.
        """
        device = self.betas.device
        B = len(cond)
        # `args` defaults to None in __init__; treat that as the flag being off.
        use_latent_prob = bool(getattr(self.args, "use_latent_prob", False))
        # The 4-tuple returned for return_noise also contains the chain and
        # log-probs, so they must be recorded whenever either flag is set
        # (previously return_noise alone raised UnboundLocalError).
        track_chain = return_diffusion or return_noise

        x = torch.randn((B, self.horizon_steps, self.action_dim), device=device)

        diffusion = [x.reshape(B, -1)] if track_chain else None
        log_probs = [] if track_chain else None
        noises = [x.reshape(B, -1)] if return_noise else None

        if self.use_ddim:
            t_all = self.ddim_t
        else:
            t_all = list(reversed(range(self.denoising_steps)))

        for i, t in enumerate(t_all):
            t_b = make_timesteps(B, t, device)
            index_b = make_timesteps(B, i, device)
            mean, logvar = self.p_mean_var(
                x=x,
                t=t_b,
                cond=cond,
                index=index_b,
            )
            std = torch.exp(0.5 * logvar)

            # Determine noise level
            if self.use_ddim:
                std = torch.zeros_like(std)
            elif t == 0 and not use_latent_prob:
                # Final step is noise-free unless latent log-probs are needed
                std = torch.zeros_like(std)
            else:
                std = torch.clip(std, min=1e-3)

            # Drawn on every step (even when std is zero) so the random stream
            # consumed per step does not depend on the branch above.
            noise = torch.randn_like(x).clamp_(
                -self.randn_clip_value, self.randn_clip_value
            )
            x = mean + std * noise

            if track_chain:
                diffusion.append(x.reshape(B, -1))
                if use_latent_prob:
                    # NOTE(review): under DDIM std is all zeros here, which
                    # Normal is expected to reject - confirm that DDIM is
                    # never combined with use_latent_prob.
                    dist = Independent(Normal(mean.reshape(B, -1), std.reshape(B, -1)), 1)
                    log_probs.append(dist.log_prob(x.detach().reshape(B, -1)).unsqueeze(-1))
                else:
                    log_probs.append(torch.zeros(B, 1, device=x.device))

            if return_noise:
                record = x if use_latent_prob else noise
                noises.append(record.reshape(B, -1))

        final = x.reshape(B, -1)
        if return_noise:
            return final, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1), torch.stack(noises, dim=1)
        if return_diffusion:
            return final, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1)
        return final
        
    def p_sample_loop_with_noise(self, cond, return_diffusion=False, return_noise=False, deterministic=False, noises=None, detach=False):
        """Replay the reverse diffusion chain with externally supplied noise.

        With K denoising steps and D = horizon_steps * action_dim:

        Args:
            cond: Conditioning input forwarded to the network; ``len(cond)``
                is used as the batch size B. NOTE(review): confirm callers
                pass a batch-first tensor rather than a dict.
            return_diffusion: Also return the chain of intermediate samples
                and per-step log-probabilities.
            return_noise: Additionally return the noise entries that were
                consumed. The chain and log-probabilities are returned too.
            deterministic: Currently unused.
            noises: Required. Indexed as ``noises[:, k]``; entry 0 is reshaped
                to (B, horizon_steps, action_dim) and used as the initial
                sample, entry ``i + 1`` is used at loop step ``i``.
            detach: If True, ``x`` is detached at the start of every step so
                gradients do not flow across steps.

        Returns:
            Same layout as ``p_sample_loop``: the final sample (B, D);
            or ``(sample, chain, log_probs)``; or
            ``(sample, chain, log_probs, noises)``.

        When the chain is being recorded and ``args.use_latent_prob`` is
        truthy, ``noises[:, i + 1]`` is scored under the step-``i`` Gaussian
        and then becomes the input of the next step, while the value appended
        to the chain is still ``mean + std * noises[:, i + 1]``.
        NOTE(review): this suggests ``noises`` holds recorded latents rather
        than noise in that mode - confirm, and confirm that the replacement
        is meant to depend on whether the chain is recorded.

        Raises:
            ValueError: ``noises`` is None.
        """
        if noises is None:
            raise ValueError("p_sample_loop_with_noise requires the `noises` tensor.")

        device = self.betas.device
        B = len(cond)
        # `args` defaults to None in __init__; treat that as the flag being off.
        use_latent_prob = bool(getattr(self.args, "use_latent_prob", False))
        # The 4-tuple returned for return_noise also contains the chain and
        # log-probs, so record them whenever either flag is set.
        track_chain = return_diffusion or return_noise

        x = noises[:, 0].reshape((B, self.horizon_steps, self.action_dim))

        diffusion = [x.reshape(B, -1)] if track_chain else None
        log_probs = [] if track_chain else None
        # Kept under its own name: rebinding `noises` to this list (as before)
        # made the `noises[:, i + 1]` lookup below fail.
        noise_chain = [x.reshape(B, -1)] if return_noise else None

        if self.use_ddim:
            t_all = self.ddim_t
        else:
            t_all = list(reversed(range(self.denoising_steps)))

        for i, t in enumerate(t_all):
            if detach:
                x = x.detach()
            t_b = make_timesteps(B, t, device)
            index_b = make_timesteps(B, i, device)
            mean, logvar = self.p_mean_var(
                x=x,
                t=t_b,
                cond=cond,
                index=index_b,
            )
            std = torch.exp(0.5 * logvar)

            # Determine noise level
            if self.use_ddim:
                std = torch.zeros_like(std)
            elif t == 0 and not use_latent_prob:
                std = torch.zeros_like(std)
            else:
                std = torch.clip(std, min=1e-3)

            noise = noises[:, i + 1].reshape(x.shape)
            x = mean + std * noise

            if track_chain:
                diffusion.append(x.reshape(B, -1))
                if use_latent_prob:
                    dist = Independent(Normal(mean.reshape(B, -1), std.reshape(B, -1)), 1)
                    # Continue the chain from the supplied entry and score it.
                    x = noise
                    log_probs.append(dist.log_prob(x.reshape(B, -1)).unsqueeze(-1))
                else:
                    log_probs.append(torch.zeros(B, 1, device=x.device))

            if return_noise:
                noise_chain.append(noise.reshape(B, -1))

        final = x.reshape(B, -1)
        if return_noise:
            return final, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1), torch.stack(noise_chain, dim=1)
        if return_diffusion:
            return final, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1)
        return final

    # ---------- Supervised training ----------#

    def loss(self, x, *args):
        """Return the denoising loss for ``x`` at randomly drawn timesteps.

        One timestep in ``[0, denoising_steps)`` is drawn per batch element on
        ``x``'s device; ``args`` are forwarded to ``p_losses`` ahead of the
        timesteps.
        """
        num_samples = len(x)
        timesteps = torch.randint(
            0,
            self.denoising_steps,
            (num_samples,),
            device=x.device,
        )
        timesteps = timesteps.long()
        return self.p_losses(x, *args, timesteps)

    def p_losses(
        self,
        x_start,
        cond: dict,
        t,
    ):
        """Mean-squared denoising loss at timesteps ``t``.

        If predicting epsilon: E_{t, x0, ε} [||ε - ε_θ(√α̅ₜ x0 + √(1-α̅ₜ) ε, t)||²];
        otherwise the network output is regressed onto ``x_start`` itself.

        Args:
            x_start: Clean sample, (batch_size, horizon_steps, action_dim).
            cond: Conditioning input forwarded to the network as ``cond=``.
            t: Batch of integer timesteps.
        """
        # Forward process: corrupt the clean sample with fresh Gaussian noise
        eps = torch.randn_like(x_start, device=x_start.device)
        noisy = self.q_sample(x_start=x_start, t=t, noise=eps)

        # Network prediction for the corrupted sample
        prediction = self.network(noisy, t, cond=cond)

        # Regression target depends on the parameterisation
        target = eps if self.predict_epsilon else x_start
        return F.mse_loss(prediction, target, reduction="mean")

    def q_sample(self, x_start, t, noise=None):
        """Draw xₜ from the forward process given x₀.

        q(xₜ | x₀) = 𝒩(xₜ; √α̅ₜ x₀, (1-α̅ₜ) I), i.e.
        xₜ = √α̅ₜ x₀ + √(1-α̅ₜ) ε.

        Fresh standard-normal noise is drawn on ``x_start``'s device when
        ``noise`` is not supplied.
        """
        if noise is None:
            noise = torch.randn_like(x_start, device=x_start.device)
        signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise
