import os
import numpy as np
import torch
import torch.nn.functional as F

def distance_point_to_line(traj, ch, r, bs=1, mode=None):                           # traj.shape = (bs*(num_steps+1), ch, r, r)
    """Perpendicular distance of each interior trajectory point to its end-to-end chord.

    ``traj`` is reshaped to ``(bs, K, ch, r, r)``, so consecutive entries are
    grouped per sample (sample-major ordering is assumed). For every sample the
    line through the first point ``b = traj[:, 0]`` and the last point
    ``c = traj[:, -1]`` is formed, and the L2 distance of every interior point
    ``a = traj[:, 1:-1]`` to that line is computed.

    Args:
        traj: tensor holding ``bs * K`` points of shape ``(ch, r, r)``.
        ch, r: per-point shape.
        bs: number of trajectories.
        mode: ``'stat'`` -> ``(mean, std)`` over the batch, each of shape ``(K-2,)``;
            ``'norm'`` -> distances of shape ``(bs, K-2)``;
            anything else -> distances of shape ``(K-2,)`` (only valid for ``bs == 1``).

    If the endpoints coincide (zero-length chord), the distance degrades to the
    plain Euclidean distance ``|c - a|`` instead of NaN.
    """
    traj = traj.reshape(bs, -1, ch, r, r)
    a, b, c = traj[:, 1:-1], traj[:, :1], traj[:, -1:]
    # calc line vectors
    ac = c - a                                                                      # (bs, K-2, ch, r, r)
    bc = c - b                                                                      # (bs, 1, ch, r, r)
    bc_len = torch.norm(bc, p=2, dim=(1, 2, 3, 4)).reshape(bs, 1, 1, 1, 1)
    # A zero-length chord means bc is all zeros; dividing by 1 then gives a zero
    # direction (no projection) instead of 0/0 = NaN.
    bc_len = torch.where(bc_len == 0, torch.ones_like(bc_len), bc_len)
    bc_unit = bc / bc_len                                                           # (bs, 1, ch, r, r)

    # Scalar projection of each ac onto the unit chord; keepdim lets it broadcast
    # back over (ch, r, r) without materialising a repeated copy.
    proj_len = torch.sum(ac * bc_unit, dim=(2, 3, 4), keepdim=True)                 # (bs, K-2, 1, 1, 1)
    perp = ac - proj_len * bc_unit                                                  # (bs, K-2, ch, r, r)
    dist = torch.norm(perp, p=2, dim=(2, 3, 4))                                     # (bs, K-2)

    if mode == 'stat':
        return torch.mean(dist, dim=0), torch.std(dist, dim=0)                      # (K-2,), (K-2,)
    if mode == 'norm':
        return dist                                                                 # (bs, K-2)
    # default test bs is 1: flatten to (K-2,)
    return dist.reshape(perp.shape[1],)


def get_sampler_settings(
    net, num_steps=18, sigma_min=None, sigma_max=None, rho=7,
    discretization='edm', schedule='edm', scaling='edm',
    epsilon_s=1e-3, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, device=None
):
    """Build the time grid and schedule callables consumed by the samplers.

    Args:
        net: object exposing ``sigma_min`` / ``sigma_max``; the requested noise
            range is clamped to it.
        num_steps: number of time steps in the grid; must be >= 2.
        sigma_min, sigma_max: noise range; ``None`` selects the default for the
            chosen ``discretization``.
        rho: exponent of the 'edm' discretization.
        discretization, schedule, scaling: each one of 'vp', 've', 'edm'.
        epsilon_s: smallest VP time; used for the VP defaults and betas.
        alpha, S_churn, S_min, S_max, S_noise: accepted for signature
            compatibility; not used here.
        device: device of the returned ``t_steps``.

    Returns:
        ``(t_steps, sigma, sigma_deriv, sigma_inv, s, s_deriv, solver)``:
        ``t_steps`` is a float64 tensor of shape ``(num_steps,)`` without the
        trailing t_N = 0; the next five are callables of t (noise level, its
        derivative, its inverse, scaling and scaling derivative); ``solver`` is
        'heun' when ``scaling == 'edm'`` and 'euler' otherwise.

    Raises:
        ValueError: if ``num_steps < 2`` (the grids divide by ``num_steps - 1``).
    """
    assert discretization in ['vp', 've', 'edm']
    assert schedule in ['vp', 've', 'edm']
    assert scaling in ['vp', 've', 'edm']
    if num_steps < 2:
        # Every discretization below divides by (num_steps - 1); a single step
        # would silently turn the whole grid into NaN.
        raise ValueError(f'num_steps must be >= 2, got {num_steps}')

    solver = 'heun' if scaling == 'edm' else 'euler'

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5

    def vp_sigma_deriv(beta_d, beta_min):
        # d sigma / dt of vp_sigma. Uses its own vp_sigma instead of the enclosing
        # `sigma` variable (bound only further down), so it is correct on its own.
        vp = vp_sigma(beta_d, beta_min)
        return lambda t: 0.5 * (beta_min + beta_d * t) * (vp(t) + 1 / vp(t))

    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP (chosen so that sigma(epsilon_s) = sigma_min
    # and sigma(1) = sigma_max).
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'edm'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(torch.as_tensor(sigma_steps))
    return t_steps, sigma, sigma_deriv, sigma_inv, s, s_deriv, solver


def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, mode=None
):
    """Sample with the EDM rho-spaced grid, Heun steps and optional noise churn.

    The noise range is clamped to ``[net.sigma_min, net.sigma_max]``. A grid of
    ``num_steps`` levels spaced uniformly in ``sigma ** (1 / rho)`` is rounded by
    ``net.round_sigma`` and followed by t_N = 0. Every step first raises the
    noise level by ``gamma`` (churn), takes an Euler step, and, except on the
    final step into 0, averages in a second slope (Heun correction).

    Returns:
        ``mode == 'trajectory'``: ``(states, t_steps, -slopes)``, where states
        are all visited x (initial included) concatenated on dim 0, and slopes
        are the first-order slopes d_cur concatenated on dim 0.
        ``mode == 'traj_both'``: ``(states, denoised)``, where denoised are the
        first-order denoiser outputs concatenated on dim 0.
        Otherwise: the final sample.
    """
    # Keep the requested noise range within what the network supports.
    lo = max(sigma_min, net.sigma_min)
    hi = min(sigma_max, net.sigma_max)

    # Rho-spaced noise levels from hi down to lo, then an explicit t_N = 0.
    inv_rho = 1 / rho
    idx = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    grid = (hi ** inv_rho + idx / (num_steps - 1) * (lo ** inv_rho - hi ** inv_rho)) ** rho
    t_steps = torch.cat([net.round_sigma(grid), torch.zeros_like(grid[:1])])

    x_next = latents.to(torch.float64) * t_steps[0]
    states = [x_next]     # x_0, x_1, ... (trajectory / traj_both)
    slopes = []           # first-order slope per step (trajectory)
    estimates = []        # first-order denoiser output per step (traj_both)

    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        x_cur = x_next

        # Churn: temporarily raise the noise level from t_cur to t_hat.
        in_churn_window = S_min <= t_cur <= S_max
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if in_churn_window else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # First-order (Euler) move from t_hat to t_next.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        dt = t_next - t_hat
        x_next = x_hat + dt * d_cur

        # Heun correction on every step except the last one, which lands on t = 0.
        if i < num_steps - 1:
            denoised_prime = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised_prime) / t_next
            x_next = x_hat + dt * (0.5 * d_cur + 0.5 * d_prime)

        if mode == 'trajectory':
            states.append(x_next)
            slopes.append(d_cur)
        elif mode == 'traj_both':
            states.append(x_next)
            estimates.append(denoised)

    if mode == 'trajectory':
        return torch.cat(states, dim=0), t_steps, -1 * torch.cat(slopes, dim=0)
    if mode == 'traj_both':
        return torch.cat(states, dim=0), torch.cat(estimates, dim=0)
    return x_next


def ablation_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like, num_steps=18, 
    alpha=1, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, mode=None,
    t_steps=None, sigma=None, sigma_deriv=None, sigma_inv=None, s=None, s_deriv=None, solver=None
):
    """Euler/Heun sampler over a precomputed time grid with generic sigma(t), s(t).

    Integrates
    ``dx/dt = (sigma'(t)/sigma(t) + s'(t)/s(t)) x - sigma'(t) s(t) / sigma(t) * D``,
    where ``D = net(x / s(t), sigma(t), class_labels)``. The schedule arguments have
    the same names as the tuple returned by ``get_sampler_settings``.

    Args:
        net: callable ``net(x, sigma, class_labels)`` with a ``round_sigma`` method.
        latents: initial noise; scaled by ``sigma(t_0) * s(t_0)``.
        num_steps: only sets the churn amount ``S_churn / num_steps``; the number of
            steps actually taken is ``len(t_steps)``.
        alpha: position of the intermediate Heun point (1 = classic Heun).
        t_steps: 1-D time grid without the trailing zero (appended here).
        sigma, sigma_deriv, sigma_inv, s, s_deriv: schedule callables of t.
        solver: 'euler' or 'heun'.
        mode: None, 'trajectory', 'norm', 'norm_denoiser' or 'traj_both'.

    Returns:
        'trajectory': ``(states, t_steps_with_zero, -slopes)``, with states and
        slopes concatenated on dim 0. 'norm': per-sample L2 norms of every state,
        shape ``(len(t_steps) + 1, bs)``. 'norm_denoiser': ``(norms, denoiser_norms)``,
        where denoiser_norms has shape ``(len(t_steps), bs)``. 'traj_both':
        ``(states, denoised)``. Otherwise: the final sample.
        The norm modes reduce over dims (1, 2, 3), i.e. they expect 4-D latents.
    """
    # The supplied grid fixes how many steps are taken. Detecting the last step from
    # it (not from `num_steps`) keeps the final step into t_N = 0 a plain Euler step
    # even when `num_steps` is left at its default. A Heun correction there would
    # divide by sigma(0) = 0.
    n_steps = t_steps.shape[0]
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
    track_norm = mode in ('norm', 'norm_denoiser')

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    traj, scores, traj_de = [x_next], [], []
    # Norms are only computed when requested, so other modes accept any latent rank.
    norm = [torch.norm(x_next, p=2, dim=(1, 2, 3))] if track_norm else []
    norm_denoiser = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction (never on the last step, which ends at t = 0).
        if solver == 'euler' or i == n_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised_prime = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised_prime
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

        if mode == 'trajectory':
            traj.append(x_next)
            scores.append(d_cur)
        elif track_norm:
            norm.append(torch.norm(x_next, p=2, dim=(1, 2, 3)))
            if mode == 'norm_denoiser':
                norm_denoiser.append(torch.norm(denoised, p=2, dim=(1, 2, 3)))
        elif mode == 'traj_both':
            traj.append(x_next)
            traj_de.append(denoised)

    # Concatenate once at the end instead of re-allocating every step.
    if mode == 'trajectory':
        return torch.cat(traj, dim=0), t_steps, -1 * torch.cat(scores, dim=0)
    elif mode == 'norm':
        return torch.stack(norm, dim=0)
    elif mode == 'norm_denoiser':
        return torch.stack(norm, dim=0), torch.stack(norm_denoiser, dim=0)
    elif mode == 'traj_both':
        return torch.cat(traj, dim=0), torch.cat(traj_de, dim=0)
    else:
        return x_next


# (n>=num_steps, ch, r, r) → (num_steps, ch, r, r)
def opt_onetap_sampler(t_steps, images_all, traj, img_channels=3, img_res=32, sigma=None, s=None):
    """Optimal-denoiser output at each point of a given trajectory, in one shot.

    For every step i, ``traj[i]`` is compared with all dataset images and replaced
    by their softmax-weighted mean, with weights ``exp(-||y - traj[i]||^2 /
    (2 sigma(t_steps[i])^2))``.

    Args:
        t_steps: 1-D tensor of times; one output per entry (extra ``traj`` entries
            beyond ``len(t_steps)`` are ignored).
        images_all: dataset images, shape ``(N, img_channels, img_res, img_res)``.
        traj: indexable by step; each ``traj[i]`` must reshape to
            ``(1, img_channels, img_res, img_res)``.
        sigma: callable mapping t to a noise level.
        s: unused; kept for signature compatibility.

    Returns:
        Tensor of shape ``(len(t_steps), img_channels, img_res, img_res)``.
    """
    outputs = []
    for t, x in zip(t_steps, traj):
        point = x.reshape(1, img_channels, img_res, img_res)
        l2_norm = torch.norm(images_all - point, p=2, dim=(1, 2, 3))        # (N,)
        noise_level = sigma(t) * torch.ones_like(l2_norm)
        temp = (-1 * l2_norm ** 2) / (2 * noise_level ** 2)
        # Softmax over the dataset axis; dim is explicit (implicit dim is deprecated).
        weights = F.softmax(temp, dim=0)
        outputs.append(torch.sum(torch.mul(images_all, weights.view(-1, 1, 1, 1)), dim=0))
    return torch.stack(outputs, dim=0)


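# mode='trajectory': (bs, ch, r, r) → ((num_steps+1)*bs, ch, r, r), ordered step-major
# (all bs samples of step 0, then all of step 1, ...).
# images_all.shape = (N, bs, ch, r, r), e.g. (50000, bs, 3, 32, 32); (N, 1, ch, r, r) also works via broadcasting.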
def opt_sampler_new(
    latents, t_steps, images_all, img_channels, img_res, randn_like=torch.randn_like, 
    alpha=1, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, mode=None,
    sigma=None, sigma_deriv=None, sigma_inv=None, s=None, s_deriv=None, solver=None
):
    """Batched sampler that uses the dataset's closed-form optimal denoiser.

    Same update rule as ``ablation_sampler``, but the denoiser output for a noisy
    input x at time t is the softmax-weighted mean of ``images_all`` with weights
    ``exp(-||y - x / s(t)||^2 / (2 sigma(t)^2))``, computed per batch element.

    Args:
        latents: ``(bs, ch, r, r)`` initial noise, scaled by ``sigma(t_0) * s(t_0)``.
        t_steps: 1-D time grid without the trailing zero (appended here).
        images_all: dataset with the sample axis first, e.g. ``(N, bs, ch, r, r)`` or
            ``(N, 1, ch, r, r)``; it must broadcast against ``(bs, ch, r, r)``.
        img_channels, img_res: unused here; kept for parity with ``opt_sampler``.
        alpha, S_churn, S_min, S_max, S_noise: Heun / churn parameters as in
            ``ablation_sampler``.
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: schedule callables and
            'euler' / 'heun'.
        mode: None, 'trajectory' or 'scores'.

    Returns:
        'trajectory': ``((len(t_steps) + 1) * bs, ch, r, r)``, step-major.
        'scores': negated first-order slopes, ``(len(t_steps), bs, ch, r, r)``.
        Otherwise: the final sample, shape ``(1, bs, ch, r, r)`` (the denoiser output
        carries a leading singleton dim that propagates into x).
    """
    batch_size = latents.shape[0]
    num_steps = t_steps.shape[0]
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])           # t_N = 0

    def optimal_denoiser(x, t):
        # Posterior mean over the dataset at noise level sigma(t). The input is
        # divided by s(t) first, exactly like the network call in ablation_sampler.
        l2_norm = torch.norm(images_all - x / s(t), p=2, dim=(2, 3, 4))       # (N, bs)
        noise_level = sigma(t) * torch.ones_like(l2_norm)
        temp = (-1 * l2_norm ** 2) / (2 * noise_level ** 2)
        weights = F.softmax(temp, dim=0)                                      # over the dataset
        weighted = torch.mul(images_all, weights.reshape(-1, batch_size, 1, 1, 1))
        return torch.sum(weighted, dim=0).unsqueeze(0).to(torch.float64)     # (1, bs, ch, r, r)

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    traj = [x_next.unsqueeze(0)]
    scores = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):    # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(torch.as_tensor(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step; the denoiser sees the churned point x_hat at level t_hat.
        h = t_next - t_hat
        denoised = optimal_denoiser(x_hat, t_hat)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction (never on the last step, which ends at t = 0).
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised_prime = optimal_denoiser(x_prime, t_prime)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised_prime
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

        if mode == 'trajectory':
            traj.append(x_next)
        elif mode == 'scores':
            scores.append(d_cur)

    if mode == 'trajectory':
        return torch.cat(traj, dim=1).squeeze(0)
    elif mode == 'scores':
        return -1 * torch.cat(scores, dim=0)
    else:
        return x_next


# (1, ch, r, r) → (num_steps+1, ch, r, r)
def opt_sampler(
    latents, t_steps, images_all, img_channels, img_res, randn_like=torch.randn_like, 
    alpha=1, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, mode=None,
    sigma=None, sigma_deriv=None, sigma_inv=None, s=None, s_deriv=None, solver=None
):
    """Single-sample sampler that uses the dataset's closed-form optimal denoiser.

    Same update rule as ``ablation_sampler``, but the denoiser output for a noisy
    input x at time t is the softmax-weighted mean of ``images_all`` with weights
    ``exp(-||y - x / s(t)||^2 / (2 sigma(t)^2))``.

    Args:
        latents: initial noise that reshapes to ``(1, img_channels, img_res, img_res)``;
            scaled by ``sigma(t_0) * s(t_0)``.
        t_steps: 1-D time grid without the trailing zero (appended here).
        images_all: dataset images, ``(N, img_channels, img_res, img_res)``.
        alpha, S_churn, S_min, S_max, S_noise: Heun / churn parameters as in
            ``ablation_sampler``.
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: schedule callables and
            'euler' / 'heun'.
        mode: None, 'trajectory' or 'scores'.

    Returns:
        'trajectory': all visited states concatenated on dim 0,
        ``(len(t_steps) + 1, ch, r, r)`` for a single latent. 'scores': negated
        first-order slopes concatenated on dim 0. Otherwise: the final sample.
    """
    num_steps = t_steps.shape[0]
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])           # t_N = 0

    def optimal_denoiser(x, t):
        # Posterior mean over the dataset at noise level sigma(t). The input is
        # divided by s(t) first, exactly like the network call in ablation_sampler.
        point = (x / s(t)).reshape(1, img_channels, img_res, img_res)
        l2_norm = torch.norm(images_all - point, p=2, dim=(1, 2, 3))          # (N,)
        noise_level = sigma(t) * torch.ones_like(l2_norm)
        temp = (-1 * l2_norm ** 2) / (2 * noise_level ** 2)
        # Softmax over the dataset axis; dim is explicit (implicit dim is deprecated).
        weights = F.softmax(temp, dim=0)
        weighted = torch.mul(images_all, weights.view(-1, 1, 1, 1))
        return torch.sum(weighted, dim=0).unsqueeze(0).to(torch.float64)     # (1, ch, r, r)

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    traj = [x_next]
    scores = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):    # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(torch.as_tensor(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step; the denoiser sees the churned point x_hat at level t_hat.
        h = t_next - t_hat
        denoised = optimal_denoiser(x_hat, t_hat)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction (never on the last step, which ends at t = 0).
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised_prime = optimal_denoiser(x_prime, t_prime)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised_prime
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

        if mode == 'trajectory':
            traj.append(x_next)
        elif mode == 'scores':
            scores.append(d_cur)

    if mode == 'trajectory':
        return torch.cat(traj, dim=0)
    elif mode == 'scores':
        return -1 * torch.cat(scores, dim=0)
    else:
        return x_next

