# Copyright 2022 Twitter, Inc and Zhendong Wang.
# SPDX-License-Identifier: Apache-2.0

import copy
from debug import debug_print
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


from onpolicy.algorithms.diffusion_ac.helpers import (cosine_beta_schedule,
                                                        linear_beta_schedule,
                                                        vp_beta_schedule,
                                                        extract,
                                                        Losses,
                                                        Progress,
                                                        Silent)

class Diffusion(nn.Module):
    """Gaussian diffusion policy built around a denoising ``model``.

    ``model(x, t, s)`` is called with a noisy action ``x``, integer timesteps
    ``t`` and a conditioning state ``s``. Depending on ``predict_epsilon`` its
    output is interpreted either as the injected noise or as the clean action.

    All schedule coefficients are registered as buffers. ``update_eta``
    re-derives them in place by interpolating geometrically between a
    constant 0.7 beta schedule (``eta == 0``) and the schedule chosen at
    construction (``eta == 1``).
    """

    def __init__(self, state_dim, action_dim, model, max_action,
                 beta_schedule='linear', n_timesteps=100,
                 loss_type='l2', clip_denoised=True, predict_epsilon=True, noise_scale=1.0):
        super(Diffusion, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.model = model
        # Multiplier applied to the log-probabilities returned by p_sample.
        self.noise_scale = noise_scale

        # Fail fast: an unknown name used to fall through and crash later
        # with an UnboundLocalError on ``betas``.
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(n_timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(n_timesteps)
        elif beta_schedule == 'vp':
            betas = vp_beta_schedule(n_timesteps)
        else:
            raise ValueError(
                f"Unknown beta_schedule {beta_schedule!r}; "
                "expected 'linear', 'cosine' or 'vp'")

        self.n_timesteps = int(n_timesteps)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        # Registration order matches the historical state_dict layout.
        self.register_buffer('betas', betas)
        # Construction-time schedule; update_eta interpolates towards it.
        self.register_buffer('target_betas', betas.clone())
        # Constant 0.7 schedule that update_eta starts from at eta == 0.
        self.register_buffer('base_betas', torch.ones_like(betas) * 0.7)
        for name, value in self._compute_schedule(betas).items():
            self.register_buffer(name, value)

        self.register_buffer('eta', torch.zeros(1))

        self.loss_fn = Losses[loss_type]()

    @staticmethod
    def _compute_schedule(betas):
        """Return every buffer derived from ``betas``, keyed by buffer name.

        Shared by ``__init__`` (registration) and ``update_eta`` (in-place
        refresh) so the two code paths cannot drift apart. Uses torch ops only,
        so it works for any device/dtype of ``betas``.
        """
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Prepend alpha_bar = 1 on the same device/dtype as the schedule.
        alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])

        # Posterior q(x_{t-1} | x_t, x_0) variance; 0 at t == 0 since
        # alphas_cumprod_prev[0] == 1.
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        return {
            'alphas_cumprod': alphas_cumprod,
            'alphas_cumprod_prev': alphas_cumprod_prev,
            # calculations for diffusion q(x_t | x_{t-1}) and others
            'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
            'sqrt_one_minus_alphas_cumprod': torch.sqrt(1. - alphas_cumprod),
            'log_one_minus_alphas_cumprod': torch.log(1. - alphas_cumprod),
            'sqrt_recip_alphas_cumprod': torch.sqrt(1. / alphas_cumprod),
            'sqrt_recipm1_alphas_cumprod': torch.sqrt(1. / alphas_cumprod - 1),
            # calculations for posterior q(x_{t-1} | x_t, x_0)
            'posterior_variance': posterior_variance,
            # log is clipped because the posterior variance is 0 at the
            # beginning of the diffusion chain
            'posterior_log_variance_clipped': torch.log(torch.clamp(posterior_variance, min=1e-20)),
            'posterior_mean_coef1': betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod),
            'posterior_mean_coef2': (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod),
        }

    def update_eta(self, v):
        """Re-derive the noise schedule for interpolation weight ``v``.

        ``betas = base_betas * (target_betas / base_betas) ** v``, so ``v == 0``
        yields ``base_betas`` and ``v == 1`` yields ``target_betas``. ``eta``,
        ``betas`` and every derived buffer are overwritten in place.
        """
        self.eta.fill_(v)

        betas = self.base_betas * (self.target_betas / self.base_betas).pow(self.eta)

        self.betas.copy_(betas)
        for name, value in self._compute_schedule(betas).items():
            getattr(self, name).copy_(value)

    # ------------------------------------------ sampling ------------------------------------------#

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the x_0 estimate from the model output.

        If ``self.predict_epsilon``, ``noise`` is the predicted noise and x_0 is
        reconstructed from ``x_t``; otherwise ``noise`` already is the x_0
        prediction and is returned unchanged.
        """
        if self.predict_epsilon:
            return (
                    extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                    extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start, x_t, t):
        """Return ``(mean, variance, clipped log-variance)`` of q(x_{t-1} | x_t, x_0).

        ``extract`` (project helper) presumably gathers the per-timestep
        coefficient and reshapes it to broadcast against ``x_t.shape``.
        """
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)

        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, s):
        """Return the reverse-step ``(mean, variance, log_variance)`` for ``x`` at ``t``.

        The model's x_0 estimate is clamped to ``[-max_action, max_action]``
        when ``clip_denoised`` is set; otherwise it is used as-is.
        """
        x_recon = self.predict_start_from_noise(x, t=t, noise=self.model(x, t, s))

        if self.clip_denoised:
            # Out-of-place: with predict_epsilon=False x_recon *is* the model
            # output, and an in-place clamp could clobber a tensor autograd saved.
            x_recon = x_recon.clamp(-self.max_action, self.max_action)
        # NOTE(review): clip_denoised=False simply skips clamping (the former
        # ``assert RuntimeError()`` here was a no-op). Raise instead if
        # unclipped mode is meant to be unsupported.

        return self.q_posterior(x_start=x_recon, x_t=x, t=t)

    def _p_sample_step(self, x, t, s, noise=None):
        """Shared body of ``p_sample`` / ``p_sample_with_noise``.

        When ``noise`` is None a standard-normal draw is taken *after* the
        model forward pass (same RNG consumption order as before).
        Returns ``(x_prev, log_probs, masked_noise)``.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, s=s)
        if noise is None:
            noise = torch.randn_like(x)
        # No noise at the final step (t == 0); mask is (batch, 1, ..., 1) so it
        # broadcasts against x.
        nonzero_mask = (1 - (t == 0).float()).reshape(x.shape[0], *((1,) * (x.dim() - 1)))
        model_std = (0.5 * model_log_variance).exp() * nonzero_mask
        noise = nonzero_mask * noise
        # Unnormalized Gaussian log-density of the applied noise, summed over
        # the last dimension (no constant / log-std terms).
        log_probs = -0.5 * noise.pow(2).sum(dim=-1, keepdim=True) * self.noise_scale
        return model_mean + model_std * noise, log_probs, noise

    def p_sample(self, x, t, s):
        """Sample one reverse step with fresh noise.

        Returns ``(x_prev, log_probs, noise)`` where ``noise`` is the
        standard-normal draw actually applied (zeroed where ``t == 0``).
        """
        return self._p_sample_step(x, t, s)

    def p_sample_with_noise(self, x, t, s, noise):
        """Sample one reverse step using the caller-supplied ``noise``.

        Returns ``(x_prev, log_probs)``; ``noise`` is masked out where ``t == 0``.
        """
        x_prev, log_probs, _ = self._p_sample_step(x, t, s, noise=noise)
        return x_prev, log_probs

    def p_sample_log_prob(self, x, x_, t, s):
        """Score the transition ``x -> x_`` under the reverse-step Gaussian.

        Returns ``(log_probs, log_std_sum)``: ``-0.5 * sum(z ** 2)`` for the
        standardized residual ``z = (x_ - mean) / std`` and ``sum(log std)``,
        both summed over the last dimension with ``keepdim=True``.

        NOTE(review): unlike ``p_sample`` this does not multiply by
        ``noise_scale``. At ``t == 0`` the clipped log-variance gives
        ``std ~ 1e-10``, so ``z`` can be huge. Confirm both are intended.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, s=s)
        log_std = 0.5 * model_log_variance
        z = (x_ - model_mean) / log_std.exp()
        log_probs = -0.5 * z.pow(2).sum(dim=-1, keepdim=True)
        return log_probs, log_std.sum(dim=-1, keepdim=True)

    def p_sample_loop(self, state, shape, verbose=False, return_diffusion=False, return_noise=False):
        """Run the full reverse chain from x_T ~ N(0, I) of ``shape``.

        Returns the final ``x`` alone; with ``return_diffusion`` also the
        stacked chain (``n_timesteps + 1`` states on dim 1) and per-step
        log-probs; with ``return_noise`` (implies ``return_diffusion``) also the
        stacked noises, whose first entry is the initial sample x_T.
        """
        device = self.betas.device

        if return_noise:
            return_diffusion = True

        batch_size = shape[0]
        x = torch.randn(shape, device=device)

        if return_diffusion:
            diffusion = [x]
            log_probs = []
        if return_noise:
            noises = [x]

        progress = Progress(self.n_timesteps) if verbose else Silent()
        for i in reversed(range(self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x, log_prob, noise = self.p_sample(x, timesteps, state)

            progress.update({'t': i})

            if return_diffusion:
                diffusion.append(x)
                log_probs.append(log_prob)
            if return_noise:
                noises.append(noise)
        progress.close()

        if return_noise:
            return x, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1), torch.stack(noises, dim=1)
        if return_diffusion:
            return x, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1)
        return x

    def p_sample_loop_with_noise(self, state, shape, verbose=False, return_diffusion=False, noises=None):
        """Replay the reverse chain deterministically from recorded ``noises``.

        ``noises[:, 0]`` is used as x_T and ``noises[:, k]`` drives the k-th
        step, so dim 1 needs at least ``n_timesteps + 1`` entries (the layout
        ``p_sample_loop(..., return_noise=True)`` produces).

        Raises:
            TypeError: if ``noises`` is not provided.
        """
        if noises is None:
            raise TypeError(
                "p_sample_loop_with_noise requires `noises` with at least "
                "n_timesteps + 1 entries on dim 1, e.g. from "
                "p_sample_loop(..., return_noise=True)")

        device = self.betas.device
        batch_size = shape[0]
        x = noises[:, 0]

        if return_diffusion:
            diffusion = [x]
            log_probs = []

        progress = Progress(self.n_timesteps) if verbose else Silent()
        for i in reversed(range(self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x, log_prob = self.p_sample_with_noise(x, timesteps, state, noises[:, self.n_timesteps - i])

            progress.update({'t': i})

            if return_diffusion:
                diffusion.append(x)
                log_probs.append(log_prob)
        progress.close()

        if return_diffusion:
            return x, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1)
        return x

    def sample(self, state, *args, **kwargs):
        """Sample actions of shape ``(batch, action_dim)`` for ``state``.

        The final action is clamped to ``[-max_action, max_action]``. When
        ``p_sample_loop`` returns a tuple (``return_diffusion`` /
        ``return_noise``) only its first element is clamped and the tuple
        is returned; previously this path raised ``AttributeError``.
        """
        batch_size = state.shape[0]
        shape = (batch_size, self.action_dim)
        result = self.p_sample_loop(state, shape, *args, **kwargs)
        if isinstance(result, tuple):
            action, *extras = result
            return (action.clamp_(-self.max_action, self.max_action), *extras)
        return result.clamp_(-self.max_action, self.max_action)

    # ------------------------------------------ training ------------------------------------------#

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to timestep ``t``: sqrt(a_bar)*x_0 + sqrt(1-a_bar)*noise."""
        if noise is None:
            noise = torch.randn_like(x_start)

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, state, t, weights=1.0):
        """Denoising MSE loss at timesteps ``t``.

        The target is the injected noise when ``predict_epsilon`` is set,
        otherwise ``x_start``.

        NOTE(review): ``weights`` is accepted but unused and ``self.loss_fn``
        is not consulted; the loss is always an unweighted mean MSE.
        """
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.model(x_noisy, t, state)

        assert noise.shape == x_recon.shape

        target = noise if self.predict_epsilon else x_start
        return F.mse_loss(x_recon, target, reduction="mean")

    def loss(self, x, state, weights=1.0):
        """Training loss with timesteps drawn uniformly from ``[0, n_timesteps)``."""
        batch_size = len(x)
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
        return self.p_losses(x, state, t, weights)

    def forward(self, state, *args, **kwargs):
        """Alias for :meth:`sample`."""
        return self.sample(state, *args, **kwargs)

