# This file heavily borrows from https://github.com/NVlabs/edm/blob/main/generate.py

from functools import partial

import numpy as np
import torch

from .utils import expand_tensor_dims_as


def edm_sampler_iterative(
    net, latents, guidance_fn=None,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    model_kwargs=None,
):
    """Run the EDM sampler (Algorithm 2) as a generator over intermediate states.

    Performs the same stochastic Heun sampling as ``edm_sampler`` but yields
    the state after every step instead of returning only the final one.

    Args:
        net: Denoiser exposing ``forward(x, sigma, **model_kwargs)``,
            ``round_sigma(sigma)`` and the ``sigma_min`` / ``sigma_max``
            attributes used to clamp the requested noise range.
        latents: Initial noise tensor; it is cast to float64 and scaled by the
            first noise level (presumably drawn from a standard normal --
            confirm with callers).
        guidance_fn: Optional wrapper applied to ``net.forward``; it must
            return a callable with the same signature.
        num_steps: Number of sampling steps (>= 1).
        sigma_min, sigma_max: Requested noise range, clamped to the network's.
        rho: Exponent of the time-step discretization.
        S_churn, S_min, S_max, S_noise: Stochasticity controls; extra noise is
            injected only when ``S_min <= t_cur <= S_max``.
        model_kwargs: Extra keyword arguments forwarded to every network call.

    Yields:
        ``num_steps + 1`` float64 tensors: the initial noisy sample followed by
        the state after each step. The last one is the final sample.
    """
    net_forward = guidance_fn(net.forward) if guidance_fn is not None else net.forward
    net_forward = partial(net_forward, **({} if model_kwargs is None else model_kwargs))
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization. The max(..., 1) guard keeps num_steps == 1 from
    # evaluating 0 / 0 (NaN); the single step then starts at sigma_max.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / max(num_steps - 1, 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), t_steps.new_zeros(1)]) # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    yield x_next
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

        # Euler step.
        denoised = net_forward(x_hat, t_hat * x_hat.new_ones(len(x_hat))).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where t_next = 0
        # and the division by t_next would be undefined).
        if i < num_steps - 1:
            denoised = net_forward(x_next, t_next * x_next.new_ones(len(x_next))).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        yield x_next

#----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).

def edm_sampler(
    net, latents, guidance_fn=None,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    model_kwargs=None, verbose=False,
):
    """Sample with the stochastic Heun EDM sampler (Algorithm 2).

    Args:
        net: Denoiser exposing ``forward(x, sigma, **model_kwargs)``,
            ``round_sigma(sigma)`` and the ``sigma_min`` / ``sigma_max``
            attributes used to clamp the requested noise range.
        latents: Initial noise tensor; it is cast to float64 and scaled by the
            first noise level (presumably drawn from a standard normal --
            confirm with callers).
        guidance_fn: Optional wrapper applied to ``net.forward``; it must
            return a callable with the same signature.
        num_steps: Number of sampling steps (>= 1).
        sigma_min, sigma_max: Requested noise range, clamped to the network's.
        rho: Exponent of the time-step discretization.
        S_churn, S_min, S_max, S_noise: Stochasticity controls; extra noise is
            injected only when ``S_min <= t_cur <= S_max``.
        model_kwargs: Extra keyword arguments forwarded to every network call.
            ``None`` (the default) means no extra arguments.
        verbose: If true, print the sampler configuration before sampling.

    Returns:
        Tuple ``(x, nfe)``: the final float64 sample and the number of network
        function evaluations (``2 * num_steps - 1``).
    """
    # None instead of a shared mutable ``{}`` default.
    if model_kwargs is None:
        model_kwargs = {}
    net_forward = guidance_fn(net.forward) if guidance_fn is not None else net.forward
    net_forward = partial(net_forward, **model_kwargs)
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    if verbose:
        print(f"Sampling with EDM sampler, {num_steps} steps, sigma_min={sigma_min}, sigma_max={sigma_max}, rho={rho}, S_churn={S_churn}, S_min={S_min}, S_max={S_max}, S_noise={S_noise}")

    # Time step discretization. The max(..., 1) guard keeps num_steps == 1 from
    # evaluating 0 / 0 (NaN); the single step then starts at sigma_max.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / max(num_steps - 1, 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), t_steps.new_zeros(1)]) # t_N = 0

    # Main sampling loop.
    nfe = 0
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

        # Euler step.
        denoised = net_forward(x_hat, t_hat * x_hat.new_ones(len(x_hat))).to(torch.float64)
        nfe += 1
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where t_next = 0
        # and the division by t_next would be undefined).
        if i < num_steps - 1:
            denoised = net_forward(x_next, t_next * x_next.new_ones(len(x_next))).to(torch.float64)
            nfe += 1
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next, nfe

#----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.


def ablation_sampler(
    net, sde, latents, guidance_fn=None, num_steps=18,
    solver='heun', alpha=1, S_churn=0, S_min=0, S_max=float('inf'),
    S_noise=1, model_kwargs=None,
):
    """Generalized EDM sampler for an arbitrary noise schedule ``sde``.

    Integrates the sampling ODE defined by the noise level ``sigma(t)`` and the
    scaling ``s(t)`` provided by ``sde``. It can add churn noise before each
    step and uses either an Euler or a Heun update.

    Args:
        net: Denoiser exposing ``forward(x, sigma, **model_kwargs)``,
            ``round_sigma(sigma)`` and ``sigma_min`` / ``sigma_max``.
        sde: Schedule object providing ``sigma``, ``sigma_inv``,
            ``sigma_deriv``, ``scale``, ``scale_deriv``,
            ``get_sigma_steps(num_steps, device=...)`` and the
            ``sigma_min`` / ``sigma_max`` attributes.
        latents: Initial noise tensor; it is cast to float64 and scaled by
            ``sigma(t_0) * s(t_0)``.
        guidance_fn: Optional wrapper applied to ``net.forward``; it must
            return a callable with the same signature.
        num_steps: Number of sampling steps.
        solver: ``'euler'`` (first order) or ``'heun'`` (second order).
        alpha: Location of the second evaluation,
            ``t_prime = t_hat + alpha * h``; ``alpha=1`` gives the classic
            Heun average of the two slopes.
        S_churn, S_min, S_max, S_noise: Stochasticity controls; extra noise is
            injected only when ``S_min <= sigma(t_cur) <= S_max``.
        model_kwargs: Extra keyword arguments forwarded to every network call.

    Returns:
        Tuple ``(x, nfe)``: the final float64 sample and the number of network
        function evaluations.

    Raises:
        AssertionError: If ``solver`` is unknown or the SDE's sigma range lies
            outside the network's. These checks are assert-based, so they are
            skipped under ``python -O``.
    """
    assert solver in ['euler', 'heun']

    net_forward = guidance_fn(net.forward) if guidance_fn is not None else net.forward
    net_forward = partial(net_forward, **({} if model_kwargs is None else model_kwargs))
    # Adjust noise levels based on what's supported by the network.
    assert sde.sigma_min >= net.sigma_min, f'Network supports sigma_min={net.sigma_min}, but SDE requires sigma_min={sde.sigma_min}'
    assert sde.sigma_max <= net.sigma_max, f'Network supports sigma_max={net.sigma_max}, but SDE requires sigma_max={sde.sigma_max}'

    def denoise(x, t):
        # The network sees the unscaled input x / s(t) and a per-sample noise level sigma(t).
        return net_forward(x / sde.scale(t), sde.sigma(t) * x.new_ones(len(x))).to(torch.float64)

    def ode_derivative(x, t, denoised):
        # dx/dt = (sigma'(t)/sigma(t) + s'(t)/s(t)) * x - sigma'(t) * s(t) / sigma(t) * denoised
        x_coef = (sde.sigma_deriv(t) / sde.sigma(t) + sde.scale_deriv(t) / sde.scale(t))
        denoised_coef = sde.sigma_deriv(t) * sde.scale(t) / sde.sigma(t)
        return expand_tensor_dims_as(x_coef, x) * x - expand_tensor_dims_as(denoised_coef, denoised) * denoised

    # Define time steps in terms of noise level.
    # NOTE(review): assumes get_sigma_steps returns exactly num_steps levels;
    # the `i == num_steps - 1` last-step check below relies on it.
    sigma_steps = sde.get_sigma_steps(num_steps, device=latents.device)

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sde.sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    nfe = 0
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sde.sigma(t_next) * sde.scale(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next
        sigma_cur = sde.sigma(t_cur)

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma_cur <= S_max else 0
        t_hat = sde.sigma_inv(net.round_sigma(sigma_cur + gamma * sigma_cur))
        x_hat = sde.scale(t_hat) / sde.scale(t_cur) * x_cur + (sde.sigma(t_hat) ** 2 - sigma_cur ** 2).clip(min=0).sqrt() * sde.scale(t_hat) * S_noise * torch.randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        d_cur = ode_derivative(x_hat, t_hat, denoise(x_hat, t_hat))
        nfe += 1

        # Apply 2nd order correction, except on the final step (t_next = 0).
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            x_prime = x_hat + alpha * h * d_cur
            t_prime = t_hat + alpha * h
            d_prime = ode_derivative(x_prime, t_prime, denoise(x_prime, t_prime))
            nfe += 1
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next, nfe