from typing import Callable, Dict, Optional, Sequence, Tuple

import torch
from torch import nn, Tensor

from azula.nn.embedding import SineEncoding
from azula.nn.unet import UNet
from azula.noise import Schedule
from azula.denoise import KarrasDenoiser, DiracPosterior
from azula.guidance import MMPSDenoiser


class NavierStokesBackbone(nn.Module):
    """
    Backbone of the denoiser for 2D Navier-Stokes.
    The diffusion time t is embedded (SineEncoding followed by a two-layer MLP with SiLU)
    and passed as modulation to a UNet, which also receives the condition channels.
    Argument(s):
        - in_channels (int): number of input channels.
        - out_channels (int): number of output channels.
        - cond_channels (int): number of condition channels.
        - hid_channels (Sequence[int]): numbers of channels at each depth.
        - hid_blocks (Sequence[int]): numbers of hidden blocks at each depth.
        - attention_heads (Dict[int, int], optional): number of attention heads at each depth.
          When None, defaults to {3: 4}.
        - spatial (int): number of spatial dimensions.
        - periodic (bool): whether the spatial dimensions are periodic or not.
        - identity_init (bool): initialize down/upsampling convolutions as identity.
    """
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        cond_channels: int = 2,
        hid_channels: Sequence[int] = (64, 128, 256, 512),
        hid_blocks: Sequence[int] = (3, 3, 3, 3),
        attention_heads: Optional[Dict[int, int]] = None,
        spatial: int = 2,
        periodic: bool = True,
        identity_init: bool = True,
    ):
        super().__init__()

        # Build a fresh dict per instance rather than sharing one mutable default
        # object across every backbone created with the default configuration.
        if attention_heads is None:
            attention_heads = {3: 4}

        # Embedding for the time
        self.emb_features = 256
        self.time_embedding = nn.Sequential(
            SineEncoding(self.emb_features),
            nn.Linear(self.emb_features, self.emb_features),
            nn.SiLU(),
            nn.Linear(self.emb_features, self.emb_features),
        )

        # UNet to approximate the velocity.
        # NOTE(review): UNet is not told the embedding width (emb_features); confirm
        # that its modulation input size matches the 256-dim time embedding.
        self.unet = UNet(
            in_channels=in_channels,
            out_channels=out_channels,
            cond_channels=cond_channels,
            # Pass lists, as the previous list defaults did, whatever sequence type the caller used.
            hid_channels=list(hid_channels),
            hid_blocks=list(hid_blocks),
            attention_heads=attention_heads,
            spatial=spatial,
            periodic=periodic,
            identity_init=identity_init,
        )

    def forward(self, z_kp1_t: Tensor, t: Tensor, x_k: Tensor) -> Tensor:
        """
        Forward (t, z^{k+1}_{t}, x^{k}) through the backbone.
        Input(s):
            - z_kp1_t (Tensor): noised next normalized residual at times t of the diffusion process with dimension (batch_size, 1, 128, 128).
            - t (Tensor): time of the diffusion process with dimension (batch_size,).
            - x_k (Tensor): clean normalized 2 previous states of the system with dimension (batch_size, 2, 128, 128).
        Output(s):
            - (Tensor): output of the UNet for the noised input.
        """
        time_emb = self.time_embedding(t)
        return self.unet(x=z_kp1_t, mod=time_emb, cond=x_k)


class KarrasScheduler(Schedule):
    """
    Create the noise schedule of "Elucidating the Design Space of Diffusion-Based Generative Models".
    See https://arxiv.org/abs/2206.00364 for more details.
    The schedule is variance-exploding: alpha(t) = 1 and
    sigma(t) = (sigma_max^(1/rho) + (1 - t) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho,
    so that sigma(0) = sigma_min and sigma(1) = sigma_max.
    Argument(s):
        - sigma_min (float): the minimum noise scale.
        - sigma_max (float): the maximum noise scale.
        - rho (float): exponent shaping how sigma is interpolated between sigma_min and sigma_max.
    """

    def __init__(
        self, sigma_min: float = 1e-3, sigma_max: float = 1e3, rho: float = 7.0
    ):
        # Initialize the parent so that, if Schedule derives from nn.Module, the module
        # internals exist and the owning denoiser can use .to(), .parameters(), state_dict().
        super().__init__()
        self.rho = rho
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Compute the noise schedule (alpha(t), sigma(t)).
        Input(s):
            - t (Tensor): time of the diffusion process, any shape.
        Output(s):
            - (Tuple[Tensor, Tensor]): alpha(t) and sigma(t), both with the same shape as t.
        """
        return self.alpha(t), self.sigma(t)

    def __call__(self, t: Tensor) -> Tuple[Tensor, Tensor]:
        # Kept so calling the schedule works whether or not Schedule is an nn.Module.
        return self.forward(t)

    def alpha(self, t: Tensor) -> Tensor:
        """Signal scale alpha(t), identically 1 for this schedule."""
        return torch.ones_like(t)

    def sigma(self, t: Tensor) -> Tensor:
        """Noise scale sigma(t), from sigma_min at t = 0 to sigma_max at t = 1."""
        a = self.sigma_max ** (1.0 / self.rho)
        b = self.sigma_min ** (1.0 / self.rho)
        return torch.pow(a + (1.0 - t) * (b - a), self.rho)


class NavierStokesDenoiser(KarrasDenoiser):
    """
    Karras denoiser to predict E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}] for the 2D Navier-Stokes system.
    It combines a NavierStokesBackbone and a KarrasScheduler, both in their default configuration.
    """
    def __init__(self):
        network = NavierStokesBackbone()
        noise_schedule = KarrasScheduler()
        super().__init__(backbone=network, schedule=noise_schedule)

    def forward(self, z_kp1_t: Tensor, t: Tensor, x_k: Tensor):  # type: ignore
        """
        Estimate E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}].
        Input(s):
            - z_kp1_t (Tensor): noised next normalized residuals at times t of the diffusion process with dimension (batch_size, 1, 128, 128).
            - t (Tensor): time of the diffusion process with dimension (batch_size,).
            - x_k (Tensor): clean normalized 2 previous states of the system with dimension (batch_size, 2, 128, 128), forwarded as conditioning.
        Output(s):
            - whatever KarrasDenoiser.forward returns for these inputs.
        """
        return super().forward(
            x_t=z_kp1_t,
            t=t,
            x_k=x_k,
        )

    def loss(self, z_kp1: Tensor, t: Tensor, x_k: Tensor):  # type: ignore
        """
        Compute the training loss of KarrasDenoiser.
        Input(s):
            - z_kp1 (Tensor): clean next normalized residuals with dimension (batch_size, 1, 128, 128).
            - t (Tensor): time of the diffusion process with dimension (batch_size,).
            - x_k (Tensor): clean normalized 2 previous states of the system with dimension (batch_size, 2, 128, 128), forwarded as conditioning.
        """
        return super().loss(  # type: ignore
            x=z_kp1,
            t=t,
            x_k=x_k,
        )


class ConditionalMMPSDenoiser(MMPSDenoiser):
    """
    Conditional denoiser to predict E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}, y^{k+1}] for the 2D Navier-Stokes system with MMPS.
    The residual estimate of the wrapped denoiser is mapped to an unnormalized next state
    x^{k+1} = std_z * hat{z}^{k+1} + (std_x * hat{x}^{k}_{last} + mean_x), where hat{x}^{k}_{last} is the last
    channel of hat{x}^{k}. The observation operator H is applied to that state, and the resulting MMPS
    correction is added to the residual estimate, which is then clamped to [lower, upper].
    Argument(s):
        - denoiser (NavierStokesDenoiser): a denoiser trained to estimate E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}].
        - y (Tensor): observation of the next state y^{k+1} with dimension (batch_size, d).
        - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 1, 128, 128) to (batch_size, d).
        - sigma_y (Tensor): covariance of the observation noise with dimension (d,) (presumably its diagonal);
          forwarded to MMPSDenoiser as cov_y.
        - mean_x (Tensor): mean of states with dimension (1, 128, 128).
        - std_x (Tensor): standard deviation of states with dimension (1, 128, 128).
        - std_z (Tensor): standard deviation of residuals with dimension (1, 128, 128).
        - lower (Tensor): lower bound for the estimation with dimension (1, 128, 128).
        - upper (Tensor): upper bound for the estimation with dimension (1, 128, 128).
        - num_iterations (int): number of iterations passed to MMPSDenoiser for its GMRES solve.
    """

    # Buffers registered in __init__ (declared here for type checkers).
    mean_x: Tensor
    std_x: Tensor
    std_z: Tensor
    lower: Tensor
    upper: Tensor

    def __init__(self,
        denoiser: NavierStokesDenoiser,
        y: Tensor,
        H: Callable[[Tensor], Tensor],
        sigma_y: Tensor,
        mean_x: Tensor,
        std_x: Tensor,
        std_z: Tensor,
        lower: Tensor,
        upper: Tensor,
        num_iterations: int = 1,
    ):
        # Use the init of MMPS
        super().__init__(
            denoiser=denoiser,
            y=y,
            A=H,
            cov_y=sigma_y,
            solver='gmres',
            iterations=num_iterations
        )

        # Save stats (as buffers so they follow the module across devices)
        self.register_buffer("mean_x", mean_x)
        self.register_buffer("std_x", std_x)
        self.register_buffer("std_z", std_z)

        # Save lower and upper bounds
        self.register_buffer("lower", lower)
        self.register_buffer("upper", upper)

    def unnormalize_states(self, x: Tensor) -> Tensor:
        """
        Unnormalize a batch of states.
        Input(s):
            - x (Tensor): normalized states with dimension (batch_size, 1, 128, 128).
        Output(s):
            - (Tensor): std_x * x + mean_x, with the same dimension as x.
        """
        return self.std_x[None, :] * x + self.mean_x[None, :]

    def unnormalize_residuals(self, z: Tensor) -> Tensor:
        """
        Unnormalize a batch of residuals.
        Input(s):
            - z (Tensor): normalized residuals with dimension (batch_size, 1, 128, 128).
        Output(s):
            - (Tensor): std_z * z, with the same dimension as z.
        """
        return self.std_z[None, :] * z

    @torch.no_grad()
    def forward(self, x_t: Tensor, t: Tensor, x_k: Tensor) -> DiracPosterior: # type: ignore
        """
        Estimate E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}, y^{k+1}].
        Input(s):
            - x_t (Tensor): noised next normalized residuals at times t of the diffusion process with dimension (batch_size, 1, 128, 128).
            - t (Tensor): time of the diffusion process with dimension (batch_size,) (a 0-dimensional time shared by the batch also works).
            - x_k (Tensor): clean normalized 2 previous states of the system with dimension (batch_size, 2, 128, 128).
        Output(s):
            - (DiracPosterior): posterior whose mean is the guided residual estimate clamped to [lower, upper].
        """
        alpha_t, sigma_t = self.schedule(t)
        gamma_t = sigma_t**2 / alpha_t

        # gamma_t has one value per entry of t (KarrasScheduler keeps t's shape). Append
        # singleton dims so a (batch_size,) gamma_t scales each sample instead of
        # broadcasting against the last spatial axis of (batch_size, 1, 128, 128) tensors.
        while gamma_t.ndim < x_t.ndim:
            gamma_t = gamma_t.unsqueeze(-1)

        with torch.enable_grad():
            x_t = x_t.detach().requires_grad_()
            q = self.denoiser(z_kp1_t=x_t, t=t, x_k=x_k)
            z_hat = q.mean
            # Map the residual estimate to the unnormalized next state, on which H acts;
            # x_k[:, -1:] is the last previous state, kept with a channel dimension.
            x = self.unnormalize_residuals(z=z_hat) + self.unnormalize_states(x=x_k[:, -1:])
            y_hat = self.A(x)

        # Jacobian-vector product of the observation operator at x.
        def A(v):
            return torch.func.jvp(self.A, (x.detach(),), (v,))[1]

        # Vector-Jacobian product of the observation operator (its transpose applied to v).
        def At(v):
            return torch.autograd.grad(y_hat, x, v, retain_graph=True)[0]

        # Covariance of x given x_t: gamma_t times the vector-Jacobian product of x w.r.t. x_t,
        # rescaled by std_z because x depends on z_hat through std_z * z_hat.
        def cov_x(v):
            return gamma_t * self.std_z[None, :] * torch.autograd.grad(x, x_t, v, retain_graph=True)[0]

        # Covariance of y given x_t: observation noise plus the propagated state covariance.
        def cov_y(v):
            return self.cov_y(v) + A(cov_x(At(v)))

        # Solve cov_y(u) = y - y_hat, then pull u back to the space of x_t.
        grad = self.y - y_hat
        grad = self.solve(A=cov_y, b=grad)
        grad = torch.autograd.grad(y_hat, x_t, grad)[0]

        posterior_clipped = torch.clamp(z_hat + gamma_t * grad, min=self.lower, max=self.upper)
        return DiracPosterior(mean=posterior_clipped)
    