import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import pdb

# import diffuser.utils as utils
from helpers import (
    cosine_beta_schedule,
    extract,
    apply_conditioning,
    Losses,
)

class OrdinaryMLP(nn.Module):
    '''
        Plain feed-forward network over a vector of size
        observation_dim + action_dim.

        Three Linear + Mish hidden layers of width `hidden_dim` are followed
        by a Linear layer that projects back to the input size, so the output
        has the same last dimension as the input.
    '''

    def __init__(self, observation_dim, action_dim, hidden_dim, loss_type='l1'):
        '''
            observation_dim : size of the observation part of the input
            action_dim      : size of the action part of the input
            hidden_dim      : width of every hidden layer
            loss_type       : key used to pick the loss class from `Losses`
        '''
        super().__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.input_dim = observation_dim + action_dim

        self.mid_layer = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
        )

        # A reward weight of 1 leaves every dimension equally weighted.
        loss_weights = self.get_loss_weights(1)
        self.loss_fn = Losses[loss_type](loss_weights)

        self.final_layer = nn.Linear(hidden_dim, self.input_dim)

    def get_loss_weights(self, reward_weight):
        '''
            Build the per-dimension loss weights: ones everywhere, with the
            last entry replaced by `reward_weight`.

            The vector is created on CUDA when it is available and on CPU
            otherwise, so the module can also be constructed on CPU-only hosts.
        '''
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        loss_weights = torch.ones(self.input_dim, dtype=torch.float32, device=device)
        loss_weights[-1] = reward_weight
        return loss_weights

    def loss(self, generated_output, ground_truth):
        '''
            Evaluate the configured loss function.

            Returns a two-element list [loss, info], exactly as unpacked from
            `self.loss_fn(generated_output, ground_truth)`.
        '''
        loss, info = self.loss_fn(generated_output, ground_truth)
        return [loss, info]

    def forward(self, x):
        '''
            Apply the hidden stack and then the output projection to `x`.
        '''
        hidden = self.mid_layer(x)
        return self.final_layer(hidden)

class GaussianDiffusion(nn.Module):
    def __init__(self, model, input_dim, condition_dim, n_timesteps=100,
                 loss_type='l1', clip_denoised=False, predict_epsilon=True,
                 reward_weight=1.0, returns_condition=False,
                 condition_guidance_w=0.1,):
        '''
            Set up the noise schedule buffers and the training objective.

            model                : denoising network; the other methods call it as
                                   model(x, t, conditions) and read model.calc_energy
            input_dim            : size of the last dimension of the diffused vector
            condition_dim        : stored on the instance only
            n_timesteps          : number of diffusion steps
            loss_type            : key used to pick the loss class from `Losses`
            clip_denoised        : clamp the reconstructed x_0 to [-1, 1] when sampling
            predict_epsilon      : model predicts the noise (True) or x_0 (False)
            reward_weight        : relative loss weight of the last dimension
            returns_condition    : stored on the instance only
            condition_guidance_w : stored on the instance only
        '''
        super().__init__()

        self.input_dim = input_dim
        self.condition_dim = condition_dim

        self.model = model
        self.returns_condition = returns_condition
        self.condition_guidance_w = condition_guidance_w
        self.alpha = 0.5  # low temperature sampling for 0.5

        betas = cosine_beta_schedule(n_timesteps, 1e-4)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Prepend 1 for "step -1". The constant is created on the schedule's own
        # device and dtype instead of a hard-coded 'cuda', so the concatenation
        # works wherever the schedule tensor lives.
        alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])

        self.n_timesteps = int(n_timesteps)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        ## log calculation clipped because the posterior variance
        ## is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
            torch.log(torch.clamp(posterior_variance, min=1e-20)))
        self.register_buffer('posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
            (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        ## get loss coefficients and initialize objective
        loss_weights = self.get_loss_weights(reward_weight)
        self.loss_fn = Losses[loss_type](loss_weights)


    def get_loss_weights(self, reward_weight):
        loss_weights = torch.ones(self.input_dim, dtype=torch.float32).to(device='cuda')
        loss_weights[-1] = reward_weight
        loss_weights = loss_weights / torch.sum(loss_weights)
        return loss_weights

    #------------------------------------------ sampling ------------------------------------------#
    def predict_start_from_noise(self,x_t, t, noise):
        '''
            if self.predict_epsilon, model output is (scaled) noise;
            otherwise, model predicts x0 directly
        '''
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise
    
    def q_posterior(self, x_start, x_t, t):
        '''
            Parameters of the posterior q(x_{t-1} | x_t, x_0).

            Returns (mean, variance, clipped log-variance). The mean is a
            linear combination of x_start and x_t whose coefficients, like the
            two variance terms, are read from the schedule buffers at step t.
        '''
        target_shape = x_t.shape

        coef_start = extract(self.posterior_mean_coef1, t, target_shape)
        coef_current = extract(self.posterior_mean_coef2, t, target_shape)
        mean = coef_start * x_start + coef_current * x_t

        variance = extract(self.posterior_variance, t, target_shape)
        log_variance_clipped = extract(self.posterior_log_variance_clipped, t, target_shape)
        return mean, variance, log_variance_clipped
    
    def p_mean_variance(self, x, t, conditions):
        '''
            p(x_t | x_{t-1}, c)
        '''
        if self.model.calc_energy:
            assert self.predict_epsilon
            x = torch.tensor(x, requires_grad=True)
            t = torch.tensor(t, dtype=torch.float, requires_grad=True)
            conditions = torch.tensor(conditions, requires_grad=True)

        # epsilon could be epsilon or x0 itself
        epsilon = self.model(x, t, conditions, use_dropout=False)
        
        t = t.detach().to(torch.int64)
        x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon)

        if self.clip_denoised:
            x_recon.clamp_(-1., 1.)
        else:
            assert RuntimeError()
        
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
                x_start=x_recon, x_t=x, t=t)
        
        return model_mean, posterior_variance, posterior_log_variance
    
    @torch.no_grad()
    def p_sample(self, x, t, conditions):
        '''
            p(x_t | x_{t-1}, c)
        '''
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x, t, conditions)
        noise = self.alpha * torch.randn_like(x)
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (self.alpha * model_log_variance).exp() * noise
    
    @torch.no_grad()
    def p_sample_loop(self, shape, conditions, apply_noise=None, verbose=True, return_diffusion=False):
        '''
            generate samples from prior
        '''
        device = self.betas.device
        
        batch_size = shape[0]
        if torch.is_tensor(apply_noise):
            x = self.alpha * apply_noise.clone()
        else:
            x = self.alpha * torch.randn(shape, device=device)

        if return_diffusion: diffusion = [x]

        for i in reversed(range(0, self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, timesteps, conditions)

            if return_diffusion: diffusion.append(x)
        
        if return_diffusion:
            return x, torch.stack(diffusion, dim=1)
        else:
            return x
    
    @torch.no_grad()
    def conditional_sample(self, conditions, *args, **kwargs):
        '''
            conditions : [ (state, action), ... ]
        '''
        batch_size = conditions.shape[0] if len(conditions.shape) > 1 else 1
        shape = (batch_size, self.input_dim)

        return self.p_sample_loop(shape, conditions, *args, **kwargs)
    
    def grad_p_sample(self, x, t, conditions):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x, t, conditions)
        noise = self.alpha * torch.randn_like(x)
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (self.alpha * model_log_variance).exp() * noise
    
    def grad_p_sample_loop(self, shape, conditions, verbose=True, return_diffusion=False):
        device = self.betas.device

        batch_size = shape[0]
        x = self.alpha * torch.randn(shape, device=device)

        if return_diffusion: diffusion = [x]

        for i in reversed(range(0, self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.grad_p_sample(x, timesteps, conditions)

            if return_diffusion: diffusion.append(x)


        if return_diffusion:
            return x, torch.stack(diffusion, dim=1)
        else:
            return x
    
    def grad_conditional_sample(self, conditions, *args, **kwargs):
        '''
            Sample given `conditions` through the gradient-enabled loop.

            conditions : [ (time, state), ... ]

            The batch size is len(conditions[0]). Extra arguments are passed
            on to `grad_p_sample_loop`.
        '''
        # NOTE(review): conditional_sample derives the batch size from
        # conditions.shape[0], while this method uses the length of the first
        # entry of `conditions` -- confirm which layout callers pass here.
        batch_size = len(conditions[0])
        shape = (batch_size, self.input_dim)

        return self.grad_p_sample_loop(shape, conditions, *args, **kwargs)

    #------------------------------------------ training ------------------------------------------#
    def q_sample(self, x_start, t, noise=None):
        '''
            Forward-diffuse x_start to step t.

            A fresh Gaussian `noise` with the shape of x_start is drawn when
            none is supplied. The result is x_start and noise combined with
            the sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod buffers
            looked up at step t.
        '''
        noise = torch.randn_like(x_start) if noise is None else noise

        signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise
    
    def p_losses(self, x_start, t, conditions, weights=None, impo_samp=False):
        '''
            loss for p(x_t | x_{t-1}, c)
        '''
        noise = torch.randn_like(x_start)


        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        if self.model.calc_energy:
            assert self.predict_epsilon
            x_noisy.requires_grad = True
            t = torch.tensor(t, dtype=torch.float, requires_grad=True)
            conditions.requires_grad = True
            noise.requires_grad = True
        
        x_recon = self.model(x_noisy, t, conditions)

        assert noise.shape == x_recon.shape

        if self.predict_epsilon:
            loss, info = self.loss_fn(x_recon, noise, weights=weights, impo_samp=impo_samp)
        else:
            loss, info = self.loss_fn(x_recon, x_start)

        return loss, info

    def loss(self, x, conditions, weights=None, impo_samp=False):
        batch_size = len(x)
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
        return self.p_losses(x, t, conditions, weights, impo_samp)

    def forward(self, conditions, *args, **kwargs):
        '''
            Calling the module samples: every argument is handed to
            `conditional_sample` and its result is returned.
        '''
        sample = self.conditional_sample(conditions, *args, **kwargs)
        return sample
    
    def load(self, path):
        self.load_state_dict(torch.load(path))
