import math

import numpy as np
import torch
import torch.nn as nn

from .base import Diffusion
from ..utils import get_kwargs


class EDM(Diffusion):
    """
    Improved preconditioning proposed in the paper
    "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM).

    The class wraps a network ``model`` that is called as
    ``model(c_in * x, c_noise, **model_kwargs)`` and exposes:

    * ``forward``  - the weighted denoising training loss,
    * ``sample``   - a stochastic sampler with an Euler step followed by a
      2nd order correction,
    * ``repaint``  - the same sampler with known regions of a reference
      tensor re-imposed at every noise level.

    NOTE(review): ``round_sigma`` is called but not defined in this class; it is
    presumably provided by ``Diffusion`` - confirm against the base class.
    """

    def __init__(self, model: nn.Module, use_fp16=False,
                 sigma_min=0., sigma_max=math.inf, sigma_data=0.5, P_mean=-1.2, P_std=1.2, **kwargs):
        """
        Args:
            model: network invoked by ``_forward`` as ``self.model``.
            use_fp16: run the network in float16 when the input lives on a CUDA device.
            sigma_min: lower bound applied to the sampler's minimum noise level.
            sigma_max: upper bound applied to the sampler's maximum noise level.
            sigma_data: constant used by the preconditioning coefficients and the loss weight.
            P_mean: mean of the normal distribution from which ``log(sigma)`` is drawn in ``forward``.
            P_std: standard deviation of that distribution.
            **kwargs: filtered through ``get_kwargs`` and handed to the base initializer.
        """
        super().__init__(**get_kwargs(**kwargs))
        # ``_forward`` reads ``self.model``. Whether the base initializer attaches it
        # cannot be told from this file, so only fill it in when it is still missing.
        if getattr(self, 'model', None) is None:
            self.model = model
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.P_mean = P_mean
        self.P_std = P_std
        self.use_fp16 = use_fp16

    def forward(self, x, **model_kwargs):
        """
        Compute the weighted denoising loss for a batch ``x``.

        One noise level per batch element is drawn as ``exp(N(P_mean, P_std))``,
        Gaussian noise of that scale is added to ``x`` and the denoiser output is
        compared with ``x``.

        Returns:
            The element-wise weighted squared error; no reduction is applied.
        """
        # One sigma per sample, shaped so that it broadcasts over the remaining dims.
        per_sample_shape = [x.shape[0]] + [1] * (x.dim() - 1)
        rnd_normal = torch.randn(per_sample_shape, device=x.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        noise = torch.randn_like(x) * sigma
        denoised = self._forward(x + noise, sigma, **model_kwargs)
        return weight * (denoised - x) ** 2

    def _forward(self, x, sigma, force_fp32=False, **model_kwargs):
        """
        Run the preconditioned denoiser ``D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)``.

        Args:
            x: noisy input; converted to float32.
            sigma: noise level tensor; converted to float32 and reshaped to
                ``(-1, 1, ..., 1)`` so that it broadcasts against ``x``.
            force_fp32: keep the network input in float32 even if ``use_fp16`` is set.
            **model_kwargs: forwarded to the network unchanged.

        Returns:
            The denoised estimate as a float32 tensor.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, *[1] * (x.dim() - 1))

        # Half precision is only used on CUDA and only when not explicitly disabled.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        total_var = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / total_var
        c_out = sigma * self.sigma_data / total_var.sqrt()
        c_in = 1 / total_var.sqrt()
        c_noise = sigma.log() / 4
        F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def _sigma_schedule(self, num_steps, sigma_min, sigma_max, rho, dtype, device):
        """Return ``num_steps`` rho-spaced noise levels from sigma_max to sigma_min, followed by 0."""
        if num_steps < 1:
            raise ValueError(f'num_steps must be at least 1, got {num_steps}')

        # Adjust noise levels based on what's supported by the network.
        sigma_min = max(sigma_min, self.sigma_min)
        sigma_max = min(sigma_max, self.sigma_max)

        # Time step discretization.
        step_indices = torch.arange(num_steps, dtype=dtype, device=device)
        # With a single step the only index is 0; guard the divisor so the
        # schedule is [sigma_max, 0] instead of NaN from 0 / 0.
        span = max(num_steps - 1, 1)
        max_inv_rho = sigma_max ** (1 / rho)
        min_inv_rho = sigma_min ** (1 / rho)
        t_steps = (max_inv_rho + step_indices / span * (min_inv_rho - max_inv_rho)) ** rho
        return torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    def _heun_step(self, num_steps, i, randn_like, t_cur, t_next, x_cur,
                   S_churn, S_min, S_max, S_noise, model_kwargs):
        """Advance ``x_cur`` from noise level ``t_cur`` to ``t_next`` (step index ``i``).

        ``model_kwargs`` is taken as a plain dict so that its keys can never clash
        with this helper's own parameter names.
        """
        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, math.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = self.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = self._forward(x_hat, t_hat, **model_kwargs).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction. It is skipped on the last step, where
        # ``t_next`` is the appended 0 and the slope would divide by it.
        if i < num_steps - 1:
            denoised = self._forward(x_next, t_next, **model_kwargs).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        return x_next

    def sample(self, latents, randn_like=torch.randn_like,
               num_steps=18, sigma_min=0.002, sigma_max=80., rho=7,
               S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **model_kwargs):
        """
        Generate samples starting from ``latents``.

        Args:
            latents: initial noise; it is scaled by the first noise level of the schedule.
            randn_like: callable producing noise shaped like its argument.
            num_steps: number of noise levels in the schedule (at least 1).
            sigma_min, sigma_max: requested noise range, clamped to the range stored on the instance.
            rho: exponent that controls the spacing of the noise levels.
            S_churn, S_min, S_max, S_noise: control the temporary noise increase;
                it is applied only while ``S_min <= t_cur <= S_max``.
            **model_kwargs: forwarded to the denoiser.

        Returns:
            The final sample. Denoiser outputs are cast to float64, so the result is float64.

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        t_steps = self._sigma_schedule(num_steps, sigma_min, sigma_max, rho, latents.dtype, latents.device)

        # Main sampling loop.
        x_next = latents * t_steps[0]
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
            x_next = self._heun_step(num_steps, i, randn_like, t_cur, t_next, x_next,
                                     S_churn, S_min, S_max, S_noise, model_kwargs)
        return x_next

    def repaint(self, x, repaint_mask, num_jumps=20, randn_like=torch.randn_like,
                num_steps=18, sigma_min=0.002, sigma_max=80., rho=7,
                S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **model_kwargs):
        """
        Regenerate the masked part of ``x`` while keeping the rest tied to ``x``.

        For every pair of consecutive noise levels the following is repeated
        ``num_jumps`` times: ``x`` is noised to the lower level, the current sample
        is denoised to the lower level with ``repaint_step``, and the two are merged
        with the mask. After every jump except the last one, the merged result is
        noised back up to the higher level.

        Args:
            x: reference tensor supplying the known values.
            repaint_mask: non-zero / True where generated values are used, zero / False
                where the (noised) values of ``x`` are kept. Converted to a bool tensor.
            num_jumps: number of resampling iterations per noise level.
            randn_like, num_steps, sigma_min, sigma_max, rho,
            S_churn, S_min, S_max, S_noise: same meaning as in ``sample``.
            **model_kwargs: forwarded to the denoiser.

        Returns:
            The merged sample at the final noise level (float64).

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        t_steps = self._sigma_schedule(num_steps, sigma_min, sigma_max, rho, x.dtype, x.device)

        # ``~`` is only a logical NOT on bool tensors (on integer masks it is a
        # bitwise complement), so normalise the mask once up front.
        unknown = torch.as_tensor(repaint_mask, device=x.device).to(torch.bool)
        known = ~unknown

        # Main sampling loop.
        x_next = randn_like(x).to(torch.float64) * t_steps[0]
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
            for jump in range(num_jumps):
                n_prev = randn_like(x) * t_next
                x_known_t_prev = x + n_prev
                x_unknown_t_prev = self.repaint_step(num_steps, i, randn_like, t_cur, t_next, x_next,
                                                     S_churn, S_min, S_max, S_noise, **model_kwargs)

                x_t_prev = x_known_t_prev * known + x_unknown_t_prev * unknown
                if jump == num_jumps - 1:
                    x_next = x_t_prev  # turn to x_{t-1}
                else:
                    # Add back exactly the variance removed by the step t_cur -> t_next.
                    noise = randn_like(x) * (t_cur.pow(2) - t_next.pow(2)).sqrt()
                    x_next = x_t_prev + noise  # new x_t
        return x_next

    def repaint_step(self, num_steps, i, randn_like, t_cur, t_next, x_next,
                     S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **model_kwargs):
        """
        Perform one sampler step from noise level ``t_cur`` to ``t_next``.

        Args:
            num_steps: total number of steps in the schedule.
            i: index of the current step; the 2nd order correction is skipped for the last one.
            randn_like: callable producing noise shaped like its argument.
            t_cur, t_next: current and target noise levels.
            x_next: sample at noise level ``t_cur``.
            S_churn, S_min, S_max, S_noise: same meaning as in ``sample``.
            **model_kwargs: forwarded to the denoiser.

        Returns:
            The sample at noise level ``t_next``.
        """
        return self._heun_step(num_steps, i, randn_like, t_cur, t_next, x_next,
                               S_churn, S_min, S_max, S_noise, model_kwargs)
