import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
from typing import Optional, List


def edm_sampler(
    net,
    noise: torch.Tensor,
    labels: torch.Tensor,
    num_steps: int = 32,
    sigma_min: float = 0.002,
    sigma_max: float = 80,
    rho: float = 7,
    guidance: float = 1.,
    temperature: float = 1.0,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float('inf'),
    S_noise: float = 1,
    dtype: torch.dtype = torch.float32,
    net_autoguidance: Optional[nn.Module] = None,
    randn_like=torch.randn_like,
    progress: bool = True,
) -> torch.Tensor:
    """Draw samples with the 2nd-order (Heun) EDM sampler.

    Follows the sampler of "Elucidating the Design Space of Diffusion-Based
    Generative Models" (Karras et al., 2022). The noise levels follow a
    rho-spaced schedule from ``sigma_max`` down to ``sigma_min`` and then 0.
    Each step optionally adds "churn" noise, takes an Euler step, and, except
    on the final step to sigma = 0, applies a Heun correction.

    Args:
        net: Denoiser exposing ``inference(x, t, labels)``. When
            ``guidance != 1`` and no ``net_autoguidance`` is given, it must also
            expose ``inference_uncond(x, t, labels, guidance)``. Outputs must be
            arithmetic-compatible with ``x``.
        noise: Initial latent. It is multiplied by ``sigma_max`` before
            sampling (presumably standard-normal noise; TODO confirm with
            callers).
        labels: Conditioning, forwarded unchanged to every network call.
        num_steps: Number of noise levels in the schedule; must be >= 1.
        sigma_min, sigma_max, rho: Parameters of the noise schedule.
        guidance: Guidance weight. 1 disables guidance.
        temperature: Divides the ODE derivative in both the Euler and the
            correction step. 1 leaves the sampler unchanged.
        S_churn, S_min, S_max, S_noise: Stochastic "churn" settings. Churn is
            applied only when ``S_churn > 0`` and ``S_min <= t <= S_max``.
        dtype: Dtype of the schedule, the latent and the denoiser outputs.
        net_autoguidance: Optional reference network for autoguidance.
        randn_like: Noise source used for churn (e.g. a seeded generator).
        progress: Show a tqdm progress bar (default True).

    Returns:
        The sample after the final step to sigma = 0, in ``dtype``.

    Raises:
        ValueError: If ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    def denoise(x, t):
        """Return the (optionally guided) denoised estimate D(x; t)."""
        if guidance == 1:
            return net.inference(x, t, labels).to(dtype)
        if net_autoguidance is not None:
            # Autoguidance from "Guiding a Diffusion Model with a Bad Version
            # of Itself" (https://arxiv.org/abs/2406.02507):
            # D = D_ref + guidance * (D_main - D_ref).
            Dx = net.inference(x, t, labels).to(dtype)
            ref_Dx = net_autoguidance.inference(x, t, labels).to(dtype)
            return ref_Dx.lerp(Dx, guidance)
        # Guidance is delegated to the network itself (presumably
        # classifier-free guidance; see net.inference_uncond).
        return net.inference_uncond(x, t, labels, guidance).to(dtype)

    # Time step discretization: rho-spaced between sigma_max and sigma_min.
    # max(..., 1) avoids a 0/0 (all-NaN schedule) when num_steps == 1.
    step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
    max_root = sigma_max ** (1 / rho)
    min_root = sigma_min ** (1 / rho)
    t_steps = (max_root + step_indices / max(num_steps - 1, 1) * (min_root - max_root)) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = noise.to(dtype) * t_steps[0]
    steps = enumerate(zip(t_steps[:-1], t_steps[1:]))  # 0, ..., N-1
    if progress:
        steps = tqdm(steps, total=num_steps, desc="diffusion sampling....")
    for i, (t_cur, t_next) in steps:
        x_cur = x_next

        # Increase noise temporarily.
        if S_churn > 0 and S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        else:
            t_hat = t_cur
            x_hat = x_cur

        # Euler step; temperature rescales the derivative.
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * (d_cur / temperature)

        # Heun correction. Skipped on the last step because t_next == 0
        # there and the slope below would divide by zero.
        if i < num_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            # Apply the same temperature scaling as the Euler step to both slopes.
            x_next = x_hat + (t_next - t_hat) * ((0.5 * d_cur + 0.5 * d_prime) / temperature)

    return x_next