"""
Ref: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py
"""

from collections import namedtuple
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.amp import autocast
from torch.nn import Module

from group_discovery.utils import blogm

# constants

# Named result pair with fields `pred_noise` and `pred_x_start`.
# NOTE(review): not referenced in the visible part of this file — presumably
# kept from the upstream implementation linked in the module docstring;
# confirm before removing.
ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start"])

# helper functions
# Short aliases for the torch.linalg matrix exponential and determinant.
expm = torch.linalg.matrix_exp
detm = torch.linalg.det


# model
class MLP(Module):
    """Time-conditioned MLP that maps a flattened input to a batch of 2x2 matrices.

    The network has three hidden ReLU layers. The (normalised) timestep is
    appended to the flattened input as one extra feature, which is why the
    first linear layer takes ``in_dim + 1`` inputs.

    Args:
        in_dim: Number of features per sample after flattening ``x_t``
            (the time feature is not included).
        out_dim: Width of the last linear layer. The output is viewed as
            ``(-1, 2, 2)``, so ``out_dim`` is expected to be 4 for one matrix
            per sample.
        hidden_dim: Width of every hidden layer.
        T: Divisor used to normalise the timestep before it is fed to the net.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, T):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        self.T = T

    def forward(self, x_t, t):
        """Predict one 2x2 matrix per sample.

        Args:
            x_t: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.
            t: 1-D tensor of timesteps, one entry per sample.

        Returns:
            Tensor of shape ``(-1, 2, 2)`` built from the network output.
        """
        # `reshape` instead of `view`: inputs produced by transposes or other
        # stride-changing ops are not contiguous and `view` would raise on them.
        flat = x_t.reshape(x_t.shape[0], -1)

        # Timestep scaled by 1/T and appended as the last input feature.
        time_feature = t.unsqueeze(-1) / self.T
        xin = torch.cat([flat, time_feature], dim=1)

        return self.net(xin).view(-1, 2, 2)


# gaussian diffusion trainer class
def extract(a, t, x_shape):
    """Pick ``a[t]`` per batch element and shape it to broadcast against ``x_shape``.

    Args:
        a: Tensor gathered along its last dimension.
        t: Index tensor; its first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
    """
    batch, *_rest = t.shape
    picked = torch.gather(a, -1, t)
    singleton_dims = [1] * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)


def linear_schedule(timesteps):
    """Return ``timesteps`` linearly spaced betas as a float64 tensor.

    The endpoints ``1e-4`` and ``0.02`` are rescaled by ``500 / timesteps``,
    so they are used unscaled when ``timesteps == 500``.
    """
    scale = 500 / timesteps
    return torch.linspace(
        start=scale * 1e-4,
        end=scale * 0.02,
        steps=timesteps,
        dtype=torch.float64,
    )


def geometric_decay_schedule(timesteps):
    """Return ``timesteps`` betas of the form ``0.4507 ** e`` as a float64 tensor.

    The exponents ``e`` are evenly spaced from 0 to ``timesteps`` inclusive,
    so the first value is 1 and the values shrink geometrically.
    """
    decay_base = 0.4507
    return torch.logspace(
        start=0,
        end=timesteps,
        steps=timesteps,
        base=decay_base,
        dtype=torch.float64,
    )


class GaussianDiffusion1D(Module):
    """Diffusion process whose noise acts on the data through matrix exponentials.

    Forward process (`q_sample`): the data is right-multiplied by the
    transposes of ``expm(beta_i * noise_i)`` for ``i = 0 .. t``. The training
    target is ``blogm`` of the accumulated matrix (``blogm`` comes from
    ``group_discovery.utils``; presumably a batched matrix logarithm — confirm).

    Reverse process (`p_sample_loop`): at every step the model output is
    transposed, scaled by ``-betas[t] / betas_cumsum[t]``, exponentiated and
    right-multiplied onto the current sample.

    Args:
        model: Callable ``model(x_t, t)``; its output is transposed over the
            last two dimensions and passed to ``expm``, so it must be a batch
            of square matrices (the `MLP` in this file emits ``[b, 2, 2]``).
        distributions_distribution: Noise source. Either a
            ``torch.distributions.Distribution`` (expanded to ``[steps, b]``
            and sampled) or a callable taking the list ``[steps, b]`` and
            returning a tensor. The result is multiplied by
            ``betas[:steps, None, None, None]``, so it is assumed to be
            ``[steps, b, k, k]`` — TODO confirm against callers.
        timesteps: Number of diffusion steps.
        objective: ``"pred_noise"`` or ``"pred_x_start"``. Sampling is only
            implemented for ``"pred_noise"``.
        scale_targets: Stored on the instance; not read anywhere in this class.
    """

    def __init__(
        self,
        model,
        *,
        distributions_distribution,
        timesteps=200,
        objective="pred_noise",
        scale_targets=False,
    ):
        super().__init__()
        self.model = model
        self.distributions_distribution = distributions_distribution
        self.objective = objective
        self.scale_targets = scale_targets

        # `geometric_decay_schedule` is the alternative schedule in this file.
        betas = linear_schedule(timesteps)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        # 11 edges -> 10 equal-width bins over [0, num_timesteps]; used by
        # `forward` to report which bin the sampled timestep fell into.
        self.bins = np.linspace(0, self.num_timesteps, 11)

        # Schedules are computed in float64 and stored as float32 buffers.
        # `loss_weight` is all ones and is not read in this class.
        schedule_buffers = (
            ("betas", betas),
            ("betas_cumsum", torch.cumsum(betas, dim=0)),
            ("loss_weight", torch.ones_like(betas)),
        )
        for name, value in schedule_buffers:
            self.register_buffer(name, value.to(torch.float32))

    @torch.no_grad()
    def p_sample_loop(self, batch, return_what="x0", clip=False):
        """Run the reverse process from ``batch`` down to timestep 0.

        Args:
            batch: Starting sample; first dimension is the batch size and the
                tensor must be right-multipliable by the per-step matrices.
            return_what: One of
                ``"x0"`` — return the final sample;
                ``"progression"`` — return ``[b, num_timesteps + 1, ...]``
                where index ``-1`` is ``batch`` and index ``t`` is the sample
                after reverse step ``t``;
                ``"transform"`` — return ``(x, M)`` where ``M`` is the
                accumulated transform satisfying
                ``x == batch @ M.transpose(-2, -1)`` (when ``clip`` is False).
            clip: If True, clamp the final sample to ``[-1, 1]`` in the last
                step.

        Raises:
            NotImplementedError: If the objective is ``"pred_x_start"``.
            ValueError: If the objective or ``return_what`` is unknown.
        """
        # Fail before spending `num_timesteps` model evaluations.
        if self.objective == "pred_x_start":
            # TODO how to compute noise from \hat{x_0}?
            raise NotImplementedError(
                "Sampling is not implemented for objective 'pred_x_start'"
            )
        if self.objective != "pred_noise":
            raise ValueError(f"Unknown objective {self.objective}")
        if return_what not in ("x0", "progression", "transform"):
            raise ValueError(f"Unknown return_what {return_what!r}")

        b = batch.shape[0]
        x = batch

        progression = None
        if return_what == "progression":
            progression = torch.zeros(
                b,
                self.num_timesteps + 1,
                *x.shape[1:],
                dtype=x.dtype,
                device=x.device,
            )
            progression[:, -1] = batch

        # Running product of the step matrices, in the order they are applied.
        applied = None

        for t in reversed(range(self.num_timesteps)):
            batched_t = torch.full((b,), t, device=batch.device, dtype=torch.long)
            pred_noise = self.model(x, batched_t).transpose(-2, -1)

            # Use expm(-A) = inv(expm(A))
            mat = expm(-(self.betas[t] / self.betas_cumsum[t]) * pred_noise)
            x = x @ mat
            if t == 0 and clip:
                x.clamp_(-1.0, 1.0)

            if progression is not None:
                progression[:, t] = x
            elif return_what == "transform":
                # x_final = batch @ mat_{T-1} @ ... @ mat_0, so later steps
                # must be multiplied on the right. Multiplying in index order
                # instead is only correct when the matrices commute.
                applied = mat if applied is None else applied @ mat

        if return_what == "x0":
            return x
        if return_what == "progression":
            return progression
        # Transposed so the result composes by left-multiplication, matching
        # how `group_action_sample` chains it with the other matrices.
        return x, applied.transpose(-2, -1)

    @torch.no_grad()
    def sample(self, batch, return_what="x0", clip=False, return_GL2_mat=False):
        """Noise ``batch`` with the full forward process, then denoise it.

        Args:
            batch: Clean data to start from.
            return_what: Forwarded to `p_sample_loop`.
            clip: Forwarded to `p_sample_loop`.
            return_GL2_mat: If True, also return the matrix exponential of the
                forward-process target returned by `q_sample`.

        Returns:
            The `p_sample_loop` result, or ``(result, GL2_mat)`` when
            ``return_GL2_mat`` is True.
        """
        # Sample from the prior: t = num_timesteps - 1 makes `q_sample` apply
        # all num_timesteps forward steps (it uses t + 1 steps).
        B = batch.shape[0]
        t = torch.full(
            (B,), self.num_timesteps - 1, device=batch.device, dtype=torch.long
        )
        new_batch, GL2_log = self.q_sample(batch, t)

        result = self.p_sample_loop(new_batch, return_what=return_what, clip=clip)
        if return_GL2_mat:
            return result, expm(GL2_log)
        return result

    def group_action_sample(self, batch):
        """Sample a matrix as the composition ``rho_3 @ rho_2 @ rho_1``.

        ``rho_1`` is the accumulated forward-process matrix for ``batch``,
        ``rho_2`` is a random 2x2 matrix per sample with determinant in
        ``[1 / det_bound, det_bound]``, and ``rho_3`` is the transform
        returned by the reverse process started from the noised batch after
        ``rho_2`` has been applied to it.

        Prints a message if any resulting determinant is negative.
        """
        B = batch.shape[0]
        t = torch.full(
            (B,), self.num_timesteps - 1, device=batch.device, dtype=torch.long
        )
        new_batch, rho_1 = self.q_sample(batch, t)
        rho_1 = expm(rho_1)

        def sample_GL2(det_bound=2):
            # Rejection sampling: entries are uniform in
            # [-det_bound, det_bound); a draw is kept only when its
            # determinant lies in [1 / det_bound, det_bound].
            out = torch.zeros_like(rho_1)
            for i in range(B):
                accepted = False
                while not accepted:
                    mat = (
                        torch.rand(2, 2, device=rho_1.device) * 2 * det_bound
                        - det_bound
                    )
                    det = mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]
                    accepted = 1 / det_bound <= det <= det_bound

                out[i] = mat

            return out

        rho_2 = sample_GL2()

        new_batch = new_batch @ rho_2.transpose(-2, -1)

        _, rho_3 = self.p_sample_loop(new_batch, return_what="transform")

        out = rho_3 @ rho_2 @ rho_1

        if (detm(out) < 0).any():
            print("[Group Action Sample] det(out) < 0")

        return out

    @autocast("cuda", enabled=False)
    @torch.no_grad()
    def q_sample(self, x_0, t):
        """Apply the forward process to ``x_0`` for ``t[0] + 1`` steps.

        Only ``t[0]`` is read, so every batch element is noised for the same
        number of steps (`forward` and `sample` build ``t`` that way).

        Args:
            x_0: Clean data; first dimension is the batch size.
            t: 1-D integer tensor of timesteps; ``t = 0`` means "sample x_1
                from x_0".

        Returns:
            Tuple ``(x_t, target)`` where ``x_t = x_0 @ final_mat`` and
            ``target = blogm(final_mat.transpose(-2, -1))``.
        """
        b = x_0.shape[0]

        # Plain int: the value is used as a shape entry and a slice bound.
        # t = 0 means one step, hence the + 1.
        cur_t = min(int(t[0]) + 1, self.num_timesteps)

        # [t, b, 2, 2]
        if isinstance(
            self.distributions_distribution, torch.distributions.Distribution
        ):
            bdist = self.distributions_distribution.expand([cur_t, b])
            noise = bdist.sample().to(x_0.device)
        else:
            noise = self.distributions_distribution([cur_t, b]).to(x_0.device)

        mats = expm(self.betas[:cur_t, None, None, None] * noise)

        # Reduce over first dim, [b, 2, 2]
        # final_mat = mats[0].T @ mats[1].T ... @ mats[t-1].T
        final_mat = reduce(torch.bmm, mats.transpose(-2, -1))

        x_t = x_0 @ final_mat

        # Return x_t and cumulative noise as target
        return x_t, blogm(final_mat.transpose(-2, -1))

    def p_losses(self, x_0, t, noise=None):
        """Return the mean-squared error between the model output and its target.

        Args:
            x_0: Clean data.
            t: Timesteps passed to `q_sample` and to the model.
            noise: Ignored; the noise target always comes from `q_sample`.

        Raises:
            ValueError: If the objective is unknown.
        """
        x_t, noise = self.q_sample(x_0=x_0, t=t)

        # predict and take gradient step
        model_out = self.model(x_t, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x_start":
            target = x_0
        else:
            raise ValueError(f"Unknown objective {self.objective}")

        # `mse_loss` already reduces with the mean, so this is a scalar.
        return F.mse_loss(model_out, target)

    def forward(self, x_0):
        """Compute the training loss at one random timestep shared by the batch.

        Returns:
            Dict with ``"loss"`` (scalar tensor), ``"t"`` (the timestep
            repeated for every batch element) and ``"time_bin"`` (index of the
            bin of ``self.bins`` the timestep falls into).
        """
        b = x_0.shape[0]

        # sample a random timestep, t=0 means sample x_1 from x_0
        t = torch.randint(
            0, self.num_timesteps, (1,), device=x_0.device, dtype=torch.int64
        ).repeat(b)

        time_bin = np.digitize(t[0].item(), self.bins).item()

        return {"loss": self.p_losses(x_0, t), "t": t, "time_bin": time_bin}
