import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import pdb
import math
import diffuser.utils as utils
from .helpers import (
    cosine_beta_schedule,
    extract,
    apply_conditioning,
    Losses,
)

@torch.jit.script
def compute_kernel(x, y):
    """Return the pairwise exponential kernel matrix between rows of x and y.

    Entry ``(i, j)`` equals ``exp(-mean_k((x[i, k] - y[j, k]) ** 2) / dim)``
    where ``dim = x.shape[1]``.

    Args:
        x: tensor viewable as ``(x_size, dim)``.
        y: tensor viewable as ``(y_size, dim)`` with the same ``dim`` as ``x``.

    Returns:
        Tensor of shape ``(x_size, y_size)``.
    """
    x_size = x.shape[0]
    y_size = y.shape[0]
    dim = x.shape[1]

    # Broadcasting (x_size, 1, dim) against (1, y_size, dim) produces every
    # pairwise difference directly; the views keep the original shape checks
    # while avoiding the two tiled copies that repeat() used to allocate.
    diff = x.view(x_size, 1, dim) - y.view(1, y_size, dim)

    return torch.exp(-torch.mean(diff ** 2, dim=2) / dim)

@torch.jit.script
def compute_mmd(x, y):
    """Return the kernel-based discrepancy between two batches of vectors.

    Computes ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` where ``k``
    is ``compute_kernel``. Both inputs must be accepted by ``compute_kernel``
    (same feature width).

    Returns:
        A scalar tensor.
    """
    within_x = compute_kernel(x, x).mean()
    within_y = compute_kernel(y, y).mean()
    across = compute_kernel(x, y).mean()
    return within_x + within_y - 2 * across

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a cumulative-product function ``alpha_bar`` into a beta schedule.

    For step ``i`` the beta is ``1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``
    with ``N = num_diffusion_timesteps``, capped at ``max_beta``.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t in [0, 1] and returning the
                      cumulative product of (1 - beta) up to that point of
                      the diffusion process.
    :param max_beta: upper bound applied to every beta; values below 1
                     prevent singularities.
    :return: a 1-D numpy array holding ``num_diffusion_timesteps`` betas.
    """
    def _beta_at(step):
        # Ratio of alpha_bar at the end and at the start of this interval.
        start = step / num_diffusion_timesteps
        end = (step + 1) / num_diffusion_timesteps
        return min(1 - alpha_bar(end) / alpha_bar(start), max_beta)

    return np.array([_beta_at(step) for step in range(num_diffusion_timesteps)])

class GaussianInvDynDiffusion(nn.Module):
    def __init__(self, model, horizon, observation_dim, action_dim, n_timesteps=200,
        loss_type='l1', clip_denoised=True, predict_epsilon=True, hidden_dim=256,
        action_weight=1.0, loss_discount=1.0, loss_weights=None, returns_condition=True,
        condition_guidance_w=1.5, ar_inv=False, train_only_inv=False, info_loss_weight=0.1, repre_type='vec', z_dim=16, pw = 'respective',):
        '''
            Build the diffusion model, the inverse-dynamics MLP, the trajectory
            encoder and all diffusion-schedule buffers.

            model : callable network, invoked elsewhere in this class as
                model(x, cond, t, returns, use_dropout=..., force_dropout=...)
            horizon : number of trajectory steps per sample
            observation_dim, action_dim : feature widths of states and actions
            n_timesteps : number of diffusion steps
            clip_denoised : clamp the reconstructed x0 to [-1, 1] while sampling
            predict_epsilon : whether the model output is treated as noise
                (True) or as x0 directly (False)
            hidden_dim : hidden width of the inverse-dynamics MLP
            loss_discount : per-step discount passed to get_loss_weights
            returns_condition, train_only_inv : stored on the instance; not
                otherwise read in this constructor
            condition_guidance_w : guidance weight combining conditional and
                unconditional model outputs during sampling
            info_loss_weight : multiplier on the auxiliary info loss
            repre_type : 'vec', 'dist' or 'none'; selects the info loss in p_losses
            z_dim : output width of the encoder mean/std heads
            pw : 'average', 'respective' or 'gaussian'; selects the reference
                distribution in p_losses when repre_type == 'dist'

            NOTE: loss_type, action_weight, loss_weights and ar_inv are accepted
            but not used here: the objective is always Losses['state_l2'], the
            loss_weights argument is overwritten by get_loss_weights, and
            get_loss_weights sets self.action_weight itself.
        '''
        super().__init__()
        self.horizon = horizon
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.transition_dim = observation_dim + action_dim
        self.model = model # temporalUnet, noise prediction network
        self.train_only_inv = train_only_inv
        self.inv_model = nn.Sequential(
                nn.Linear(2 * self.observation_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, self.action_dim)) # inverse dynamic model, (s, s')->a
        self.returns_condition = returns_condition
        self.condition_guidance_w = condition_guidance_w
        self.info_loss_weight = info_loss_weight
        self.repre_type = repre_type
        self.pw = pw

        # Encoder over a flattened (horizon * observation_dim) trajectory with
        # two linear heads of width z_dim.
        self.encoder = nn.Sequential(
                      nn.Linear(horizon * self.observation_dim, 4*16),
                      nn.ReLU(),)
        self.encoder_mean = nn.Linear(4*16, z_dim)
        self.encoder_std = nn.Linear(4*16, z_dim)

        # Cosine-shaped alpha_bar discretized into betas (float32 tensor).
        betas = torch.from_numpy(betas_for_alpha_bar(n_timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)).to(dtype=torch.float32)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

        self.n_timesteps = int(n_timesteps)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        ## log calculation clipped because the posterior variance
        ## is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
            torch.log(torch.clamp(posterior_variance, min=1e-20)))
        # torch.sqrt keeps these computations on torch tensors end to end,
        # rather than routing them through numpy's array protocol.
        self.register_buffer('posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
            (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        ## get loss coefficients and initialize objective
        loss_weights = self.get_loss_weights(loss_discount)
        self.loss_fn = Losses['state_l2'](loss_weights)

    def get_loss_weights(self, discount):
        '''
            sets loss coefficients for trajectory

            action_weight   : float
                coefficient on first action loss
            discount   : float
                multiplies t^th timestep of trajectory loss by discount**t
            weights_dict    : dict
                { i: c } multiplies dimension i of observation loss by c
        '''
        self.action_weight = 1
        dim_weights = torch.ones(self.observation_dim, dtype=torch.float32)

        ## decay loss with trajectory timestep: discount**t
        discounts = discount ** torch.arange(self.horizon, dtype=torch.float)
        discounts = discounts / discounts.mean()
        loss_weights = torch.einsum('h,t->ht', discounts, dim_weights)
        # Cause things are conditioned on t=0
        if self.predict_epsilon:
            loss_weights[0, :] = 0

        return loss_weights

    #------------------------------------------ sampling ------------------------------------------#

    def predict_start_from_noise(self, x_t, t, noise):
        '''
            if self.predict_epsilon, model output is (scaled) noise;
            otherwise, model predicts x0 directly
        '''
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start, x_t, t):
        '''
            Return (mean, variance, clipped log-variance) of the posterior
            q(x_{t-1} | x_t, x_0) at step t.

            The mean is posterior_mean_coef1[t] * x_start +
            posterior_mean_coef2[t] * x_t; all three quantities are looked up
            from the registered buffers through extract().
        '''
        target_shape = x_t.shape
        coef_start = extract(self.posterior_mean_coef1, t, target_shape)
        coef_current = extract(self.posterior_mean_coef2, t, target_shape)
        mean = coef_start * x_start + coef_current * x_t

        variance = extract(self.posterior_variance, t, target_shape)
        log_variance = extract(self.posterior_log_variance_clipped, t, target_shape)
        return mean, variance, log_variance

    def p_mean_variance(self, x, cond, t, returns=None):
        '''
            Compute the reverse-step posterior parameters for the sample x at step t.

            The model is evaluated twice -- once with use_dropout=False and once
            with force_dropout=True -- and the two outputs are combined as
                uncond + condition_guidance_w * (cond - uncond).
            The combined output is converted to an x0 estimate, optionally
            clamped to [-1, 1] (self.clip_denoised), and passed to q_posterior.

            Returns (model_mean, posterior_variance, posterior_log_variance).
        '''
        eps_cond = self.model(x, cond, t, returns, use_dropout=False)
        eps_uncond = self.model(x, cond, t, returns, force_dropout=True)
        guided_eps = eps_uncond + self.condition_guidance_w*(eps_cond - eps_uncond)

        # buffers are indexed with integer steps
        step = t.detach().to(torch.int64)
        x0_estimate = self.predict_start_from_noise(x, t=step, noise=guided_eps)

        if self.clip_denoised:
            x0_estimate.clamp_(-1., 1.)

        return self.q_posterior(x_start=x0_estimate, x_t=x, t=step)

    @torch.no_grad()
    def p_sample(self, x, cond, t, returns=None):
        '''
            Draw x_{t-1} from the model's reverse step given x at step t.

            The Gaussian noise is scaled by 0.5 before being multiplied by the
            posterior standard deviation, and is masked out entirely for
            batch entries whose step is 0.
        '''
        batch = x.shape[0]
        mean, _, log_variance = self.p_mean_variance(x=x, cond=cond, t=t, returns=returns)
        eps = 0.5*torch.randn_like(x)
        # 1 where t != 0, 0 where t == 0; shaped (batch, 1, ..., 1) for broadcasting
        trailing_ones = (1,) * (x.dim() - 1)
        keep_noise = (1 - (t == 0).float()).reshape(batch, *trailing_ones)
        std = (0.5 * log_variance).exp()
        return mean + keep_noise * std * eps

    @torch.no_grad()
    def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusion=False):
        '''
            Run the full reverse diffusion chain from step n_timesteps - 1 down to 0.

            shape : shape of the sample to draw; shape[0] is the batch size
            cond : conditioning passed to apply_conditioning(x, cond, 0) before
                the loop and after every step, and to the model
            returns : forwarded to p_sample
            verbose : accepted but not used in this method
            return_diffusion : when truthy, also return every intermediate
                sample stacked along dim 1 (initial sample first)

            The starting noise is scaled by 0.5.
        '''
        device = self.betas.device
        n_batch = shape[0]

        x = apply_conditioning(0.5*torch.randn(shape, device=device), cond, 0)
        history = [x] if return_diffusion else None

        for step in range(self.n_timesteps - 1, -1, -1):
            t = torch.full((n_batch,), step, device=device, dtype=torch.long)
            x = apply_conditioning(self.p_sample(x, cond, t, returns), cond, 0)
            if history is not None:
                history.append(x)

        if history is not None:
            return x, torch.stack(history, dim=1)
        return x

    @torch.no_grad()
    def conditional_sample(self, cond, returns=None, horizon=None, *args, **kwargs):
        '''
            Sample a (batch, self.horizon, self.observation_dim) trajectory
            conditioned on cond, where batch is cond.shape[0].

            horizon is accepted but not used: the sample length is always
            self.horizon. Extra positional/keyword arguments are forwarded to
            p_sample_loop.
        '''
        n_batch = cond.shape[0]
        sample_shape = (n_batch, self.horizon, self.observation_dim)
        return self.p_sample_loop(sample_shape, cond, returns, *args, **kwargs)
    #------------------------------------------ training ------------------------------------------#

    def q_sample(self, x_start, t, noise=None):
        '''
            Noise x_start forward to diffusion step t.

            Returns sqrt_alphas_cumprod[t] * x_start +
            sqrt_one_minus_alphas_cumprod[t] * noise. Fresh standard-normal
            noise is drawn when noise is None.
        '''
        noise = torch.randn_like(x_start) if noise is None else noise

        signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def p_losses(self, x_start, cond, t, returns=None, phi_std=None):
        '''
        Compute the denoising loss and the auxiliary "info" loss for one batch.

        Shapes below are taken from the original author's notes and should be
        confirmed against the caller:
        x_start : the state part of the batch, x[:, :, action_dim:], i.e.
            [s0, s1, ..., s_{H-1}], (batch_size, horizon, state_dim)
        cond : the condition, (batch_size, state_dim)
        t : (batch_size,), each entry a random diffusion step in [0, n_timesteps)
        returns : the original note says (batch_size, 1); however the 'vec'
            path feeds it to compute_mmd together with the (batch_size, z_dim)
            encoder mean, which requires matching widths -- TODO confirm
        phi_std : only read when repre_type == 'dist' and pw is 'average' or
            'respective' (its values) or 'gaussian' (its shape); it is
            exponentiated to form a diagonal covariance

        Returns (loss, info_loss). With repre_type == 'none' the info loss is
        the float 0.0 and loss is the plain denoising loss; otherwise loss
        already includes info_loss_weight * info_loss.

        Raises ValueError for an unsupported repre_type, or an unsupported pw
        when repre_type == 'dist'.
        '''
        noise = torch.randn_like(x_start) # same shape as the states

        # add t steps of noise to x_start
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        # original author's note: replaces the first step s0 of the trajectory with the condition
        x_noisy = apply_conditioning(x_noisy, cond, 0)

        x_recon = self.model(x_noisy, cond, t, returns)

        if not self.predict_epsilon:
            x_recon = apply_conditioning(x_recon, cond, 0)
        assert noise.shape == x_recon.shape

        if self.predict_epsilon:
            # plain (unweighted) mean squared error against the injected noise
            loss = ((noise - x_recon) ** 2).mean()
        else:
            loss, _ = self.loss_fn(x_recon, x_start)

        # Nothing below is needed without an info loss, so skip the encoder.
        if self.repre_type == 'none':
            return loss, 0.0

        # Encode the model output into the representation space.
        embed = self.encoder(x_recon.reshape(cond.shape[0], -1))
        generated_phi_mean = self.encoder_mean(embed)

        if self.repre_type == 'dist':
            generated_phi_std = torch.clamp(self.encoder_std(embed), min=-5, max=2)

            # Select the reference distribution the encoder output is pulled towards.
            if self.pw == 'average':
                # batch-averaged mean / std, repeated for every sample
                batch = cond.shape[0]
                ref_loc = returns.mean(0).unsqueeze(0).repeat(batch, 1)
                ref_log_var = phi_std.mean(0).unsqueeze(0).repeat(batch, 1)
            elif self.pw == 'respective':
                ref_loc = returns
                ref_log_var = phi_std
            elif self.pw == 'gaussian':
                # zero mean; the covariance diagonal is exp(1), not 1
                ref_loc = torch.zeros_like(returns).to(returns.device)
                ref_log_var = torch.ones_like(phi_std).to(returns.device)
            else:
                raise ValueError(
                    f"unsupported pw {self.pw!r}; expected 'average', 'respective' or 'gaussian'")

            phi_dist = torch.distributions.MultivariateNormal(
                loc=ref_loc,
                covariance_matrix=torch.diag_embed(torch.exp(ref_log_var)))
            generated_phi_dist = torch.distributions.MultivariateNormal(
                loc=generated_phi_mean,
                covariance_matrix=torch.diag_embed(torch.exp(generated_phi_std)))
            info_loss = torch.distributions.kl_divergence(generated_phi_dist, phi_dist).mean()
        elif self.repre_type == 'vec':
            info_loss = compute_mmd(returns, generated_phi_mean)
        else:
            raise ValueError(
                f"unsupported repre_type {self.repre_type!r}; expected 'dist', 'vec' or 'none'")

        loss = loss + self.info_loss_weight * info_loss
        return loss, info_loss

    def loss(self, x, cond, returns=None, phi_std=None): #  batch =(x, condition, returns)
        '''
        Compute the diffusion, inverse-dynamics and info losses for one batch.

        x is laid out as [actions | states] along the last axis: the first
        action_dim features are the action and the remainder is the state.

        Example shapes from the original author's notes (confirm against the
        data loader): x is (32, 100, 14) with batch size 32, horizon 100 and
        14 = 11 state dims + 3 action dims; the condition holds (32, 11)
        tensors, i.e. it looks state-shaped; returns is (32, 1).

        Returns (diffuse_loss, inv_loss, info_loss), where inv_loss is the MSE
        of the inverse-dynamics model predicting a_t from (s_t, s_{t+1}).
        '''
        n_batch = len(x)
        t = torch.randint(0, self.n_timesteps, (n_batch,), device=x.device).long()

        states = x[:, :, self.action_dim:]
        diffuse_loss, info_loss = self.p_losses(states, cond, t, returns, phi_std)

        # consecutive state pairs (s_t, s_{t+1}) and the action taken between them
        current_states = states[:, :-1]   # [s0, ..., s_{H-2}]
        next_states = states[:, 1:]       # [s1, ..., s_{H-1}]
        actions = x[:, :-1, :self.action_dim]

        state_pairs = torch.cat([current_states, next_states], dim=-1)
        state_pairs = state_pairs.reshape(-1, 2 * self.observation_dim)
        actions = actions.reshape(-1, self.action_dim)

        predicted_actions = self.inv_model(state_pairs) # (s_t, s_t+1) -> a_t
        inv_loss = F.mse_loss(predicted_actions, actions)

        return diffuse_loss, inv_loss, info_loss

    def forward(self, cond, *args, **kwargs):
        return self.conditional_sample(cond=cond, *args, **kwargs)

    def generate(self, cond, returns):
        '''
            Produce encoder outputs (mean, clamped std head) for a model
            output generated from scaled random noise.

            NOTE: this performs a single model call at a randomly drawn
            diffusion step per sample; no iterative denoising loop is run.
            The second return value is the encoder_std head clamped to
            [-5, 2]; p_losses exponentiates the same head to build a
            diagonal covariance.
        '''
        device = self.betas.device
        n_batch = cond.shape[0]
        sample_shape = (n_batch, self.horizon, self.observation_dim)

        # random step per sample, then 0.5-scaled noise with conditioning applied
        t = torch.randint(0, self.n_timesteps, (n_batch,), device=device).long()
        x = apply_conditioning(0.5*torch.randn(sample_shape, device=device), cond, 0)
        x = self.model(x, cond, t, returns)

        hidden = self.encoder(x.reshape(n_batch, -1))
        mean = self.encoder_mean(hidden)
        std = torch.clamp(self.encoder_std(hidden), min=-5, max=2)
        return mean, std

    def dpm_sample(self, cond, returns):
        '''
            Sample a trajectory using DPM_Solver from .dpm_solver_pytorch
            with 20 solver steps, starting from 0.5-scaled Gaussian noise
            with conditioning applied.

            The solver is given a wrapped model function that returns the
            guided noise estimate
                uncond + condition_guidance_w * (cond - uncond).

            Returns whatever dpm_solver.sample(x, steps=20) returns.
        '''
        shape = (cond.shape[0], self.horizon, self.observation_dim)
        device = self.betas.device
        x = 0.5*torch.randn(shape, device=device)
        x = apply_conditioning(x, cond, 0)

        # Imported lazily so the solver module is only needed when this method is called.
        from .dpm_solver_pytorch import DPM_Solver, NoiseScheduleVP
        def wrap_model(model_fn):
            def wrapped_model_fn(x, t):
                # Rescale the solver's time argument using ns.total_N
                # (presumably continuous time -> discrete step index; confirm
                # against dpm_solver_pytorch). `ns` is bound below, before the
                # solver first calls this closure.
                t = (t - 1. / ns.total_N) * ns.total_N
                epsilon_cond = model_fn(x, cond, t, returns, use_dropout=False)
                epsilon_uncond = model_fn(x, cond, t, returns, force_dropout=True)
                epsilon = epsilon_uncond + self.condition_guidance_w*(epsilon_cond - epsilon_uncond)
                # NOTE(review): everything from here to the return computes an
                # ancestral-sampling update into the local `x`, but only
                # `epsilon` is returned, so these results are discarded. The
                # randn_like call still draws from the global RNG.
                x_recon = self.predict_start_from_noise(x, t=t.to(torch.int64), noise=epsilon)
                x_recon = torch.clamp(x_recon, -1.0, 1.0)
                model_mean, posterior_variance, model_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t.to(torch.int64))
                noise = 0.5*torch.randn_like(x)
                nonzero_mask = (1 - (t == 0).float()).view(cond.shape[0], *((1,) * (len(x.shape) - 1)))
                x =  model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
                x = apply_conditioning(x, cond, 0)
                # epsilon = epsilon_uncond
                return epsilon
            return wrapped_model_fn
        # Discrete noise schedule built from this model's registered betas.
        ns = NoiseScheduleVP(schedule='discrete', betas=self.betas)
        dpm_solver = DPM_Solver(model_fn=wrap_model(self.model), noise_schedule=ns)
        out = dpm_solver.sample(x, steps=20)
        return out

    def p_sample_loop_jit(self, model, cond, returns, horizon, obs_dim, n_timesteps=200):
        '''
            Reverse-diffusion sampling loop written with an explicit counter
            and an externally supplied model.

            model : network handed to predict() for the conditional and
                unconditional passes
            cond : conditioning; cond.shape[0] is used as the batch size
            returns : forwarded to predict(); its device is used for all tensors
            horizon, obs_dim : trajectory length and feature width of the sample
            n_timesteps : number of reverse steps (iterated from
                n_timesteps - 1 down to 0)

            Returns the final sample of shape (batch, horizon, obs_dim).

            NOTE(review): `predict` is not defined or imported in the visible
            part of this file -- confirm it exists at module level, otherwise
            this method raises NameError. The guidance weight here is the
            literal 1.2 rather than self.condition_guidance_w, and this method
            is not decorated with torch.no_grad() unlike p_sample_loop.
        '''
        shape = (cond.shape[0], horizon, obs_dim)
        # starting noise is scaled by 0.5
        x = 0.5*torch.randn(*shape, device=returns.device)
        x = apply_conditioning(x, cond, 0)
        # step indices in descending order: n_timesteps - 1, ..., 0
        indices = torch.arange(n_timesteps).flip(0).to(device=returns.device)
        i = 0
        
        # print(model.time_mlp)
        # model = torch.jit.trace(model, (x, cond, indices[0], returns))
        # print(model.time_mlp)

        while i < n_timesteps:
            # same step index for every batch entry
            t = indices[i].expand(x.shape[0]).long()
            epsilon_cond = predict(model, x, cond, t, returns, use_dropout=False)
            epsilon_uncond = predict(model, x, cond, t, returns, force_dropout=True)
            # epsilon_cond = model.forward(x, cond, t, returns, use_dropout=False)
            # epsilon_uncond = model.forward(x, cond, t, returns, force_dropout=True)
            epsilon = epsilon_uncond + 1.2*(epsilon_cond - epsilon_uncond)

            x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon)
            # always clamps, regardless of self.clip_denoised
            x_recon = torch.clamp(x_recon, -1.0, 1.0)
            model_mean, posterior_variance, model_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
            noise = 0.5*torch.randn_like(x)
            # no noise is added for entries whose step is 0
            nonzero_mask = (1 - (t == 0).float()).view(cond.shape[0], *((1,) * (len(x.shape) - 1)))
            x =  model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
            x = apply_conditioning(x, cond, 0)
            i += 1

        return x