#-*- coding:utf-8 -*-
# 
# original code -> https://github.com/crowsonkb/k-diffusion
# 

import math

import torch
from torch import nn
from tqdm.auto import trange

# from . import sampling, utils

def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` axes.

    Args:
        x: Tensor (anything exposing ``ndim`` and tuple indexing).
        target_dims: Desired total number of dimensions.

    Returns:
        ``x`` indexed with ``...`` followed by one ``None`` per missing axis,
        i.e. the same data with size-1 axes added at the end.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # Ellipsis keeps every existing axis; each None inserts a new trailing one.
    index = (Ellipsis,) + (None,) * missing
    return x[index]

def append_zero(x):
    """Return ``x`` with one extra zero element concatenated at the end.

    The zero is created with ``x.new_zeros`` so it shares ``x``'s dtype and
    device; concatenation happens along the first axis.
    """
    trailing_zero = x.new_zeros([1])
    return torch.cat([x, trailing_zero])

def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.

    The first axis is treated as the batch axis and is always preserved.

    Args:
        tensor: Input tensor; axis 0 is the batch axis.

    Returns:
        A tensor holding one mean per batch entry. When the input has no
        non-batch axes (``ndim <= 1``) there is nothing to average, so a copy
        of the input is returned unchanged.
    """
    non_batch_dims = list(range(1, tensor.ndim))
    if not non_batch_dims:
        # NOTE: passing an empty dim list to ``Tensor.mean`` is understood to
        # reduce over *every* axis (collapsing the batch axis into a scalar),
        # which contradicts the documented contract. Handle it explicitly.
        return tensor.clone()
    return tensor.mean(dim=non_batch_dims)

def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cuda'):
    """Build the noise schedule of Karras et al. (2022).

    Args:
        n: Number of non-zero noise levels.
        sigma_min: Smallest non-zero noise level (last before the final zero).
        sigma_max: Largest noise level (first entry).
        rho: Exponent controlling how the levels are spaced.
        device: Device the result is moved to (defaults to ``'cuda'``).

    Returns:
        A tensor of ``n + 1`` values that runs from ``sigma_max`` down to
        ``sigma_min`` and ends with an appended ``0``.
    """
    inv_rho = 1 / rho
    low = sigma_min ** inv_rho
    high = sigma_max ** inv_rho
    # Interpolate linearly in sigma**(1/rho) space, then map back with **rho.
    steps = torch.linspace(0, 1, n)
    schedule = (high + steps * (low - high)) ** rho
    return append_zero(schedule).to(device)

def to_d(x, sigma, denoised):
    """Turn a denoiser output into a Karras ODE derivative.

    Computes ``(x - denoised) / sigma`` with ``sigma`` given trailing
    singleton axes so that it broadcasts against ``x``.
    """
    residual = x - denoised
    sigma_expanded = append_dims(sigma, x.ndim)
    return residual / sigma_expanded

@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M).

    Args:
        model: Callable ``model(x, sigma_batch, **extra_args)`` returning the
            denoised estimate for ``x``.
        x: Initial sample; its first axis is used as the batch size.
        sigmas: Sequence of noise levels; one step is taken per adjacent pair.
        extra_args: Optional keyword arguments forwarded to ``model``.
        callback: Optional callable invoked once per step with a dict holding
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` to control the progress bar.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])

    def to_sigma(t):
        # Inverse of ``to_t``: sigma = exp(-t).
        return t.neg().exp()

    def to_t(sigma):
        # Negative-log time parameterisation: t = -log(sigma).
        return sigma.log().neg()

    prev_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_cur, 'denoised': denoised})
        t_cur, t_next = to_t(sigma_cur), to_t(sigma_next)
        h = t_next - t_cur
        if prev_denoised is None or sigma_next == 0:
            # First step (no history yet) or final step to sigma == 0:
            # use the current denoised estimate on its own.
            estimate = denoised
        else:
            # Second-order multistep: combine current and previous estimates,
            # weighted by the ratio of the last step size to this one.
            h_last = t_cur - to_t(sigmas[i - 1])
            r = h_last / h
            half_inv_r = 1 / (2 * r)
            estimate = (1 + half_inv_r) * denoised - half_inv_r * prev_denoised
        x = (to_sigma(t_next) / to_sigma(t_cur)) * x - (-h).expm1() * estimate
        prev_denoised = denoised
    return x

@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Args:
        model: Callable ``model(x, sigma_batch, **extra_args)`` returning the
            denoised estimate for ``x``.
        x: Initial sample; its first axis is used as the batch size.
        sigmas: Sequence of noise levels; one step is taken per adjacent pair.
        extra_args: Optional keyword arguments forwarded to ``model``.
        callback: Optional callable invoked once per step with a dict holding
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` to control the progress bar.
        s_churn: Total churn budget; divided by the number of steps and capped
            at ``sqrt(2) - 1`` to obtain the per-step ``gamma``.
        s_tmin: Lower bound of the noise-level range in which churn applies.
        s_tmax: Upper bound of the noise-level range in which churn applies.
        s_noise: Scale applied to the freshly drawn noise.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn on every iteration, whether or not it ends up used.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            # Raise the sample's noise level from sigma_cur to sigma_hat.
            x = x + noise * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigma_next - sigma_hat
        # The Euler prediction is the final answer on the last step and the
        # evaluation point for the correction otherwise.
        x_euler = x + d * dt
        if sigma_next == 0:
            x = x_euler
        else:
            denoised_next = model(x_euler, sigma_next * batch_ones, **extra_args)
            d_next = to_d(x_euler, sigma_next, denoised_next)
            d_avg = (d + d_next) / 2
            x = x + d_avg * dt
    return x

@torch.no_grad()
def calc_dxdt(model, x:torch.Tensor, sigma:torch.Tensor, extra_args:dict):
    """Evaluate ``(x - model(x, sigma)) / sigma`` with gradients disabled.

    Args:
        model: Callable ``model(x, sigma, **extra_args)`` returning the
            denoised estimate.
        x: Current sample.
        sigma: Noise level; given trailing singleton axes to broadcast with ``x``.
        extra_args: Keyword arguments forwarded to ``model``.

    Returns:
        The derivative estimate, broadcast against ``x``.
    """
    denoised = model(x, sigma, **extra_args)
    sigma_expanded = append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_expanded

@torch.no_grad()
def calc_dxdt_guidance(model, x:torch.Tensor, sigma:torch.Tensor, extra_args:dict, w:float=7.0):
    """Evaluate a guided derivative from two differently conditioned model calls.

    The model is run once with ``extra_args['base_global_cond']`` and once
    with ``extra_args['global_cond']`` (each passed as ``global_cond``), and
    the two derivatives are combined as ``base + w * (cond - base)``.
    Only these two keys of ``extra_args`` are read; nothing else is forwarded.

    Args:
        model: Callable ``model(x, sigma, global_cond=...)`` returning the
            denoised estimate.
        x: Current sample.
        sigma: Noise level; given trailing singleton axes to broadcast with ``x``.
        extra_args: Dict providing ``'global_cond'`` and ``'base_global_cond'``.
        w: Guidance weight applied to the difference between the derivatives.

    Returns:
        The blended derivative.

    Raises:
        KeyError: If either required key is missing from ``extra_args``.
    """
    target_cond = extra_args['global_cond']
    base_cond = extra_args['base_global_cond']

    def slope_for(global_cond):
        denoised = model(x, sigma, global_cond=global_cond)
        return (x - denoised) / append_dims(sigma, x.ndim)

    # Base conditioning is evaluated first, then the target conditioning.
    base_slope = slope_for(base_cond)
    target_slope = slope_for(target_cond)
    return base_slope + w*(target_slope - base_slope)

@torch.no_grad()
def sample_heun_simplified(model, x, sigmas, extra_args=None):
    """Deterministic Heun sampler without churn, callbacks or a progress bar.

    For every adjacent pair of noise levels an Euler step is taken; unless the
    next level is zero, the derivative is re-evaluated at the Euler prediction
    and the two derivatives are averaged for the actual update.

    Args:
        model: Denoiser passed through to ``calc_dxdt``.
        x: Initial sample.
        sigmas: Sequence of noise levels.
        extra_args: Optional keyword arguments forwarded to the model.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    for idx in range(len(sigmas) - 1):
        sigma_cur = sigmas[idx]
        slope_start = calc_dxdt(model, x, sigma_cur, extra_args)

        sigma_next = sigmas[idx + 1]
        step = sigma_next - sigma_cur
        x_euler = x + slope_start * append_dims(step, x.ndim)
        if sigma_next == 0:
            # Final step: keep the plain Euler result.
            x = x_euler
            continue

        # Heun correction: average the slopes at both ends of the step.
        slope_end = calc_dxdt(model, x_euler, sigma_next, extra_args)
        x = x + (slope_start + slope_end) * append_dims(step / 2, x.ndim)
    return x

@torch.no_grad()
def sample_heun_with_guidance(model, x, sigmas, w, extra_args=None):
    """Deterministic Heun sampler whose derivatives come from ``calc_dxdt_guidance``.

    For every adjacent pair of noise levels an Euler step is taken; unless the
    next level is zero, the guided derivative is re-evaluated at the Euler
    prediction and the two derivatives are averaged for the actual update.

    Args:
        model: Denoiser passed through to ``calc_dxdt_guidance``.
        x: Initial sample.
        sigmas: Sequence of noise levels.
        w: Guidance weight forwarded to ``calc_dxdt_guidance``.
        extra_args: Optional dict forwarded to ``calc_dxdt_guidance``.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    n_steps = len(sigmas) - 1
    for step_idx in range(n_steps):
        sigma_from = sigmas[step_idx]
        d_first = calc_dxdt_guidance(model, x, sigma_from, extra_args, w=w)

        sigma_to = sigmas[step_idx + 1]
        delta = sigma_to - sigma_from
        x_pred = x + d_first * append_dims(delta, x.ndim)
        if sigma_to == 0:
            # Final step: keep the plain Euler result.
            x = x_pred
        else:
            # Heun correction: average the guided slopes at both ends.
            d_second = calc_dxdt_guidance(model, x_pred, sigma_to, extra_args, w=w)
            x = x + (d_first + d_second) * append_dims(delta / 2, x.ndim)
    return x


class Denoiser(nn.Module):
    """A Karras et al. preconditioner for denoising diffusion models.

    Wraps ``inner_model`` so that its input is scaled by ``c_in`` and its
    output is combined with the input as ``output * c_out + input * c_skip``.
    """

    def __init__(self, inner_model, sigma_data=1.):
        """Store the wrapped network and the ``sigma_data`` constant."""
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = sigma_data

    def get_scalings(self, sigma):
        """Return the ``(c_skip, c_out, c_in)`` scalings for noise level ``sigma``."""
        data_var = self.sigma_data ** 2
        total_var = sigma ** 2 + data_var
        total_std = total_var ** 0.5
        c_skip = data_var / total_var
        c_out = sigma * self.sigma_data / total_std
        c_in = 1 / total_std
        return c_skip, c_out, c_in

    def loss(self, input, noise, sigma, mask = None, **kwargs):
        """Compute the per-sample squared-error training loss.

        Args:
            input: Clean data.
            noise: Noise tensor, scaled by ``sigma`` and added to ``input``.
            sigma: Noise levels; broadcast against ``input`` via trailing axes.
            mask: Optional index mask. Where it selects entries, the noised
                input is reset to the clean input and the squared error is
                multiplied by zero. Presumably a boolean tensor shaped like
                ``input`` — confirm against callers.
            **kwargs: Forwarded to ``inner_model``.

        Returns:
            Squared error averaged over all non-batch elements (masked
            entries still count in the denominator).
        """
        ndim = input.ndim
        c_skip, c_out, c_in = (append_dims(s, ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * append_dims(sigma, ndim)
        use_mask = mask is not None
        if use_mask:
            noised_input[mask] = input[mask]
        model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
        # Regression target such that output * c_out + noised * c_skip == input.
        target = (input - c_skip * noised_input) / c_out
        sq_err = (model_output - target).pow(2)
        if use_mask:
            keep = ~mask
            sq_err = sq_err * keep.type(sq_err.dtype)
        return sq_err.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate for ``input`` at noise level ``sigma``."""
        ndim = input.ndim
        c_skip, c_out, c_in = (append_dims(s, ndim) for s in self.get_scalings(sigma))
        raw_output = self.inner_model(input * c_in, sigma, **kwargs)
        return raw_output * c_out + input * c_skip


class DenoiserWithVariance(Denoiser):
    """Denoiser whose loss is weighted by a log-variance returned by the inner model."""

    def loss(self, input, noise, sigma, **kwargs):
        """Compute the per-sample variance-weighted loss.

        The inner model is called with ``return_variance=True`` and must
        return ``(model_output, logvar)``. Each element contributes
        ``(squared_error / exp(logvar) + logvar) / 2``.

        Args:
            input: Clean data.
            noise: Noise tensor, scaled by ``sigma`` and added to ``input``.
            sigma: Noise levels; broadcast against ``input`` via trailing axes.
            **kwargs: Forwarded to ``inner_model``.

        Returns:
            The loss averaged over all non-batch elements.
        """
        ndim = input.ndim
        c_skip, c_out, c_in = (append_dims(s, ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * append_dims(sigma, ndim)
        model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
        # Give logvar trailing axes so it broadcasts against the model output.
        logvar = append_dims(logvar, model_output.ndim)
        target = (input - c_skip * noised_input) / c_out
        sq_err = (model_output - target) ** 2
        per_element = (sq_err / logvar.exp() + logvar) / 2
        return per_element.flatten(1).mean(1)


class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion.

    The inner model is conditioned on ``sigma_to_t(sigma)`` rather than on
    ``sigma`` directly, and ``sigma_data`` is fixed at 1.
    """

    def __init__(self, inner_model):
        """Store the wrapped network; ``sigma_data`` is always ``1.``."""
        super().__init__()
        self.inner_model = inner_model # Same model as DDPM
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return ``(c_skip, c_out, c_in)`` for ``sigma``; note ``c_out`` is negative."""
        data_var = self.sigma_data ** 2
        total_var = sigma ** 2 + data_var
        total_std = total_var ** 0.5
        c_skip = data_var / total_var
        c_out = -sigma * self.sigma_data / total_std
        c_in = 1 / total_std
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        """Map a noise level to the inner model's timestep: ``atan(sigma) / pi * 2``."""
        angle = sigma.atan()
        return angle / math.pi * 2

    def t_to_sigma(self, t):
        """Inverse of ``sigma_to_t``: ``tan(t * pi / 2)``."""
        angle = t * math.pi / 2
        return angle.tan()

    def loss(self, input, noise, sigma, **kwargs):
        """Compute the per-sample squared-error training loss.

        Args:
            input: Clean data.
            noise: Noise tensor, scaled by ``sigma`` and added to ``input``.
            sigma: Noise levels; broadcast against ``input`` via trailing axes.
            **kwargs: Forwarded to ``inner_model``.

        Returns:
            Squared error averaged over all non-batch elements.
        """
        ndim = input.ndim
        c_skip, c_out, c_in = (append_dims(s, ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * append_dims(sigma, ndim)
        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        # Regression target such that output * c_out + noised * c_skip == input.
        target = (input - c_skip * noised_input) / c_out
        sq_err = (model_output - target).pow(2)
        return sq_err.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate for ``input`` at noise level ``sigma``."""
        ndim = input.ndim
        c_skip, c_out, c_in = (append_dims(s, ndim) for s in self.get_scalings(sigma))
        raw_output = self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return raw_output * c_out + input * c_skip

