"""
Minimal 1D Gaussian diffusion (DDPM-style) model and training/sampling wrapper.

Adapted from: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py
"""

from collections import namedtuple
from functools import partial

import torch
import torch.nn.functional as F
from einops import reduce
from torch import nn
from torch.amp import autocast
from torch.nn import Module

# constants
# Pair of (predicted noise, predicted clean sample) returned by
# GaussianDiffusion1D.model_predictions.
ModelPrediction = namedtuple("ModelPrediction", "pred_noise pred_x_start")

# helper functions: short aliases for torch.linalg routines
expm, inv = torch.linalg.matrix_exp, torch.linalg.inv


def identity(t, *args, **kwargs):
    """No-op transform: hand back ``t`` untouched.

    Any extra positional or keyword arguments are accepted and discarded, so
    this can stand in for an optional transform such as a partially-applied
    ``torch.clamp``.
    """
    del args, kwargs  # intentionally unused
    return t


# model
class MLP(Module):
    """Small time-conditioned MLP denoiser.

    Each sample of ``x_t`` is flattened, concatenated with its timestep divided
    by ``T``, passed through a ReLU MLP (three hidden layers of width
    ``hidden_dim``), and reshaped back to ``x_t.shape``.

    Args:
        in_dim: number of features per sample once ``x_t`` is flattened.
        out_dim: number of output features; must equal the flattened size of
            one sample so the output can be reshaped to ``x_t.shape``.
        hidden_dim: width of the hidden layers.
        T: divisor applied to the timestep (presumably the number of diffusion
            steps, so the time feature lies in [0, 1)).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, T):
        super().__init__()
        self.net = nn.Sequential(
            # +1 input for the scalar time feature.
            nn.Linear(in_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        self.T = T

    def forward(self, x_t, t):
        """Predict a tensor shaped like ``x_t`` from ``x_t`` and timesteps ``t``.

        Args:
            x_t: batch tensor whose first dimension is the batch size.
            t: 1-D tensor of timesteps, one per batch element.

        Returns:
            Network output reshaped to ``x_t.shape``.
        """
        shape = x_t.shape

        # reshape (not view) so non-contiguous inputs (e.g. transposed) work.
        flat = x_t.reshape(shape[0], -1)

        # Normalize the (integer) timestep and match x_t's dtype so the
        # concatenation below never mixes dtypes.
        t_feat = t.unsqueeze(-1).to(flat.dtype) / self.T

        out = self.net(torch.cat([flat, t_feat], dim=1))
        return out.reshape(shape)


# gaussian diffusion: schedule helpers and trainer class
def extract(a, t, x_shape):
    """Look up per-timestep coefficients and shape them for broadcasting.

    Args:
        a: tensor of per-timestep values (1-D as used in this file).
        t: 1-D integer index tensor, one timestep per batch element.
        x_shape: shape of the tensor the result will be broadcast against.

    Returns:
        ``a`` gathered at ``t`` and reshaped to ``(len(t), 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)


def linear_beta_schedule(timesteps):
    """Return a linearly spaced beta schedule rescaled to ``timesteps`` steps.

    The endpoints 1e-4 and 0.02 are the values used for a 1000-step chain; both
    are multiplied by ``1000 / timesteps``, so shorter chains use larger betas.
    The final beta equals ``20 / timesteps``, so ``timesteps <= 20`` produces
    betas >= 1, which are not valid noise variances.

    Returns:
        float64 tensor of shape ``(timesteps,)``.
    """
    factor = 1000 / timesteps
    start, end = factor * 1e-4, factor * 0.02
    return torch.linspace(start, end, timesteps, dtype=torch.float64)


class GaussianDiffusion1D(Module):
    """Gaussian diffusion wrapper around a denoising network.

    The wrapped ``model`` is called as ``model(x_t, t)`` with a noisy batch
    ``x_t`` and a 1-D tensor of integer timesteps ``t``, and is expected to
    return a tensor shaped like ``x_t``. What that output represents depends on
    ``objective``:

    * ``"pred_noise"``: the noise that was mixed into ``x_0``.
    * ``"pred_x0"``: the clean sample ``x_0`` (``"pred_x_0"`` is accepted as
      an alias).
    * ``"pred_v"``: ``sqrt(a_bar_t) * noise - sqrt(1 - a_bar_t) * x_0``.

    Calling the module on a clean batch returns the scalar training loss;
    :meth:`sample` draws new samples with the reverse (ancestral) chain.
    """

    _OBJECTIVES = ("pred_noise", "pred_x0", "pred_v")
    # Alternative spellings normalized to a canonical objective name.
    _OBJECTIVE_ALIASES = {"pred_x_0": "pred_x0"}

    def __init__(
        self,
        model,
        *,
        timesteps=100,
        objective="pred_noise",
    ):
        """
        Args:
            model: denoising network called as ``model(x_t, t)``.
            timesteps: number of diffusion steps.
            objective: one of ``"pred_noise"``, ``"pred_x0"`` (alias
                ``"pred_x_0"``) or ``"pred_v"``.

        Raises:
            ValueError: if ``objective`` is not recognized.
        """
        super().__init__()

        # Normalize and validate once so every method compares against a
        # single spelling (training and sampling previously disagreed).
        objective = self._OBJECTIVE_ALIASES.get(objective, objective)
        if objective not in self._OBJECTIVES:
            raise ValueError(f"Unknown objective {objective}")

        self.model = model
        self.num_timesteps = int(timesteps)
        self.objective = objective

        betas = linear_beta_schedule(self.num_timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        def register_buffer(name, val):
            # Schedules are computed in float64; store them as float32 buffers.
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for the forward process q(x_t | x_0) and its inverses
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )

        # Per-timestep loss weight derived from the signal-to-noise ratio.
        snr = alphas_cumprod / (1 - alphas_cumprod)

        if objective == "pred_noise":
            loss_weight = torch.ones_like(snr)
        elif objective == "pred_x0":
            loss_weight = snr
        else:  # "pred_v" (validated above)
            loss_weight = snr / (snr + 1)

        register_buffer("loss_weight", loss_weight)

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover ``x_0`` from ``x_t`` and the noise that produced it."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from ``x_t`` and a (predicted) ``x_0``."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_v(self, x_start, t, noise):
        """Return the v-target ``sqrt(a_bar) * noise - sqrt(1 - a_bar) * x_0``."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
            - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """Recover ``x_0`` from ``x_t`` and a v-prediction."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(
        self, x, t, x_self_cond=None, clip_x_start=False, rederive_pred_noise=False
    ):
        """Run the model and convert its output to (noise, x_0) predictions.

        Args:
            x: noisy batch ``x_t``.
            t: 1-D tensor of timesteps.
            x_self_cond: accepted for interface compatibility; unused.
            clip_x_start: clamp the predicted ``x_0`` to ``[-1, 1]``.
            rederive_pred_noise: with ``"pred_noise"`` and ``clip_x_start``,
                recompute the noise from the clamped ``x_0``.

        Returns:
            ``ModelPrediction(pred_noise, pred_x_start)``.
        """
        model_output = self.model(x, t)
        maybe_clip = (
            partial(torch.clamp, min=-1.0, max=1.0) if clip_x_start else identity
        )

        if self.objective == "pred_noise":
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == "pred_x0":
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == "pred_v":
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        else:
            raise ValueError(f"Unknown objective {self.objective}")

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, x_self_cond=None, clip_denoised=True):
        """Return the reverse-step mean/variance/log-variance and predicted ``x_0``."""
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            # Out-of-place so the model output is never mutated (safe under autograd).
            x_start = x_start.clamp(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_start, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, x_self_cond=None, clip_denoised=True):
        """Take one reverse-diffusion step from integer timestep ``t``.

        No noise is added when ``t == 0``.

        Returns:
            ``(next_sample, predicted_x0)``.
        """
        b = x.shape[0]
        batched_times = torch.full((b,), t, device=x.device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x, t=batched_times, x_self_cond=x_self_cond, clip_denoised=clip_denoised
        )
        noise = torch.randn_like(x) if t > 0 else 0.0  # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, batch_size, return_all=False, shape=(2, 2)):
        """Run the full reverse chain starting from standard Gaussian noise.

        Args:
            batch_size: number of samples to draw.
            return_all: also return every intermediate state.
            shape: per-sample shape (defaults to the previously hard-coded
                ``(2, 2)``).

        Returns:
            Final samples of shape ``(batch_size, *shape)``, or, when
            ``return_all`` is true, a tensor of shape
            ``(batch_size, num_timesteps + 1, *shape)`` where index
            ``num_timesteps`` is the initial noise and index ``t`` is the
            output of the reverse step at timestep ``t`` (so index 0 is the
            final sample).
        """
        device = self.betas.device
        img = torch.randn(batch_size, *shape, device=device)

        out = None
        if return_all:
            out = torch.zeros(
                batch_size,
                self.num_timesteps + 1,
                *img.shape[1:],
                device=device,
                dtype=img.dtype,
            )
            out[:, -1] = img

        for t in reversed(range(self.num_timesteps)):
            img, _ = self.p_sample(img, t, None)

            if return_all:
                out[:, t] = img  # slice assignment copies the data

        return out if return_all else img

    @torch.no_grad()
    def sample(self, batch_size, return_all=False, shape=(2, 2)):
        """Draw samples; see :meth:`p_sample_loop` for arguments and return value."""
        return self.p_sample_loop(batch_size, return_all=return_all, shape=shape)

    @autocast("cuda", enabled=False)
    def q_sample(self, x_0, t, noise=None):
        """Diffuse ``x_0`` to timestep ``t`` (runs with CUDA autocast disabled).

        ``x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise`` with
        ``noise ~ N(0, I)`` unless ``noise`` is supplied.

        Returns:
            ``(x_t, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        x_t = (
            extract(self.sqrt_alphas_cumprod, t, x_0.shape) * x_0
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_0.shape) * noise
        )
        return x_t, noise

    def p_losses(self, x_0, t, noise=None):
        """Weighted MSE training loss for clean batch ``x_0`` at timesteps ``t``.

        Args:
            x_0: clean batch.
            t: 1-D tensor of timesteps, one per sample.
            noise: optional noise to use instead of a fresh draw.

        Returns:
            Scalar loss (per-sample MSE times ``loss_weight[t]``, averaged).

        Raises:
            ValueError: if ``self.objective`` is not recognized.
        """
        x_t, noise = self.q_sample(x_0=x_0, t=t, noise=noise)

        model_out = self.model(x_t, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_0
        elif self.objective == "pred_v":
            target = self.predict_v(x_0, t, noise)
        else:
            raise ValueError(f"Unknown objective {self.objective}")

        loss = F.mse_loss(model_out, target, reduction="none")
        loss = reduce(loss, "b ... -> b", "mean")

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, x_0):
        """Return the training loss for ``x_0`` at uniformly sampled timesteps."""
        b = x_0.shape[0]

        t = torch.randint(
            0, self.num_timesteps, (b,), device=x_0.device, dtype=torch.int64
        )

        return self.p_losses(x_0, t)
