import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils import spectral_norm
from azula.nn.embedding import SineEncoding
from azula.noise import Schedule
from typing import Callable, Tuple
from azula.denoise import KarrasDenoiser, DiracPosterior
from azula.guidance import MMPSDenoiser

class ResidualBlock(nn.Module):
    def __init__(self, dim: int, time_dim: int, residual_scale: float = 0.8):
        """
        Build a time-modulated residual block for the denoiser.
        Argument(s):
            - dim (int): width of the block input and output.
            - time_dim (int): width of the time embedding fed to the block.
            - residual_scale (float): factor applied to the residual branch before it is added to the input.
        """
        super().__init__()
        self.dim, self.residual_scale = dim, residual_scale

        # Two (spectrally normalized linear -> layer norm) stages
        self.fc1 = spectral_norm(nn.Linear(dim, dim))
        self.ln1 = nn.LayerNorm(dim)
        self.fc2 = spectral_norm(nn.Linear(dim, dim))
        self.ln2 = nn.LayerNorm(dim)

        # Maps the time embedding to 2 * dim values: one half scales, the other half shifts
        self.time_mlp = nn.Linear(time_dim, 2 * dim)

        # Non-linearity applied after the modulation
        self.act = nn.SiLU()

    def forward(self, x: Tensor, t_emb: Tensor) -> Tensor:
        """
        Apply the residual block to a batch of features.
        Input(s):
            - x (Tensor): input features with dimension (batch_size, dim).
            - t_emb (Tensor): time embeddings with dimension (batch_size, time_dim).
        Output:
            - Tensor with the same dimension as x, equal to x + residual_scale * branch(x, t_emb).
        """
        # Split the projected time embedding into a multiplicative and an additive part
        modulation = self.time_mlp(t_emb)
        gain, offset = torch.chunk(modulation, 2, dim=-1)

        # First stage, modulated by the time embedding before the activation
        hidden = self.ln1(self.fc1(x))
        hidden = self.act(hidden * (1.0 + gain) + offset)

        # Second stage
        hidden = self.ln2(self.fc2(hidden))

        # Skip connection with a scaled residual branch
        return x + self.residual_scale * hidden


class Lorenz63Backbone(nn.Module):
    """
    Backbone of the denoiser for the stochastic Lorenz63 system.
    Argument(s):
        - input_dim (int): dimension of states/residuals (also the dimension of the output).
        - hidden_dim (int): dimension of hidden states within the network.
        - num_blocks (int): number of residual blocks.
        - time_dim (int): dimension of the time embeddings.
        - residual_scale (float): residual scaling coefficient.
    """

    def __init__(
        self,
        input_dim: int = 3,
        hidden_dim: int = 256,
        num_blocks: int = 8,
        time_dim: int = 256,
        residual_scale: float = 0.8,
    ):
        # Parameters
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_blocks = num_blocks
        self.time_dim = time_dim
        self.residual_scale = residual_scale

        # Time embedding
        self.time_embedding = nn.Sequential(
            SineEncoding(self.time_dim),
            nn.Linear(self.time_dim, self.time_dim),
            nn.SiLU(),
            nn.Linear(self.time_dim, self.time_dim),
        )

        # Input projection layer (takes the concatenation of the previous
        # state and the noised residual, hence 2 * input_dim features)
        self.input_projection = nn.Sequential(
            spectral_norm(nn.Linear(2 * self.input_dim, self.hidden_dim)),
            nn.LayerNorm(self.hidden_dim),
            nn.SiLU(),
            spectral_norm(nn.Linear(self.hidden_dim, self.hidden_dim)),
            nn.LayerNorm(self.hidden_dim),
            nn.SiLU(),
        )

        # Residual blocks
        self.blocks = nn.ModuleList(
            [
                ResidualBlock(
                    dim=self.hidden_dim,
                    time_dim=self.time_dim,
                    residual_scale=self.residual_scale,
                )
                for _ in range(self.num_blocks)
            ]
        )

        # Output projection layer. The output width follows input_dim instead
        # of a hard-coded 3 so that the prediction always has the dimension
        # of the residual being denoised.
        self.output_projection = nn.Sequential(
            spectral_norm(nn.Linear(self.hidden_dim, self.hidden_dim)),
            nn.LayerNorm(self.hidden_dim),
            nn.SiLU(),
            nn.Linear(self.hidden_dim, self.input_dim),
        )
        # Zero-initialize the last linear layer: tanh(0) = 0, so the backbone
        # outputs exactly zero before any training step.
        with torch.no_grad():
            self.output_projection[-1].weight.zero_()  # type: ignore
            self.output_projection[-1].bias.zero_()  # type: ignore

        # Scaling for the output: a strictly positive (Softplus) per-component
        # factor computed from the time embedding.
        self.out_scale = nn.Sequential(
            nn.Linear(self.time_dim, 32),
            nn.SiLU(),
            nn.Linear(32, self.input_dim),
            nn.Softplus(),
        )

    def forward(self, z_kp1_t: Tensor, t: Tensor, x_k: Tensor) -> Tensor:
        """
        Forward (t, z^{k+1}_{t}, x^{k}) through the backbone.
        Input(s):
            - z_kp1_t (Tensor): noised next normalized residual at times t of the diffusion process with dimension (batch_size, input_dim).
            - t (Tensor): time of the diffusion process with dimension (batch_size,).
            - x_k (Tensor): clean normalized previous states of the system with dimension (batch_size, input_dim).
        Output:
            - Tensor with dimension (batch_size, input_dim); each component is bounded in magnitude by the time-dependent scale since it is scale * tanh(.).
        """
        # Get the time embeddings
        time_embeddings = self.time_embedding(t)

        # Concatenate inputs (previous state first, then noised residual)
        inputs = torch.cat([x_k, z_kp1_t], dim=-1)

        # Input projection layer
        h = self.input_projection(inputs)

        # Pass through residual blocks (each uses t_emb for FiLM)
        for b in self.blocks:
            h = b(x=h, t_emb=time_embeddings)

        # Output projection layer
        output = self.output_projection(h)

        # Scaling coefficient
        scale = self.out_scale(time_embeddings)

        return scale * torch.tanh(output)
    


class KarrasScheduler(Schedule):
    """
    Create the noise schedule of "Elucidating the Design Space of Diffusion-Based Generative Models".
    See https://arxiv.org/abs/2206.00364 for more details.
    The signal scale is constant (alpha = 1) and the noise scale goes from
    sigma_min at t = 0 to sigma_max at t = 1.
    Argument(s):
        - sigma_min (float): the minimum noise scale.
        - sigma_max (float): the maximum noise scale.
        - rho (float): exponent of the interpolation between the noise scales.
    """

    def __init__(
        self, sigma_min: float = 1e-3, sigma_max: float = 1e3, rho: float = 7.0
    ):
        # Initialize the base class before setting attributes. If Schedule is
        # an nn.Module (not visible from this file), skipping this leaves the
        # instance without its internal registries.
        super().__init__()
        self.rho = rho
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, t: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Return the pair (alpha(t), sigma(t)); both have the same shape as t.
        """
        return self.alpha(t), self.sigma(t)

    def alpha(self, t: Tensor) -> Tensor:
        """
        Signal scale: identically one, with the shape of t.
        """
        return torch.ones_like(t)

    def sigma(self, t: Tensor) -> Tensor:
        """
        Noise scale: (sigma_max^(1/rho) + (1 - t) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho.
        """
        a = self.sigma_max ** (1.0 / self.rho)
        b = self.sigma_min ** (1.0 / self.rho)
        return torch.pow(a + (1.0 - t) * (b - a), self.rho)


class Lorenz63Denoiser(KarrasDenoiser):
    """
    Karras denoiser estimating E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}] for the stochastic Lorenz63 system.
    It pairs a default Lorenz63Backbone with a default KarrasScheduler.
    """

    def __init__(self):
        # Build both components with their default hyper-parameters
        network = Lorenz63Backbone()
        noise_schedule = KarrasScheduler()
        super().__init__(backbone=network, schedule=noise_schedule)

    def forward(self, z_kp1_t: Tensor, t: Tensor, x_k: Tensor):  # type: ignore
        """
        Estimate E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}].
        Input(s):
            - z_kp1_t (Tensor): noised next normalized residuals at times t of the diffusion process with dimension (batch_size, 3).
            - t (Tensor): time of the diffusion process with dimension (batch_size,).
            - x_k (Tensor): clean normalized previous states of the system with dimension (batch_size, 3).
        Output:
            - whatever the parent forward returns for these inputs.
        """
        # The parent names the noised sample x_t; x_k is forwarded as an extra keyword
        estimate = super().forward(x_t=z_kp1_t, t=t, x_k=x_k)
        return estimate

    def loss(self, z_kp1: Tensor, t: Tensor, x_k: Tensor):  # type: ignore
        """
        Compute the training loss.
        Input(s):
            - z_kp1 (Tensor): clean next normalized residuals with dimension (batch_size, 3).
            - t (Tensor): time of the diffusion process with dimension (batch_size,).
            - x_k (Tensor): clean normalized previous states of the system with dimension (batch_size, 3).
        Output:
            - whatever the parent loss returns for these inputs.
        """
        # The parent names the clean sample x; x_k is forwarded as an extra keyword
        value = super().loss(x=z_kp1, t=t, x_k=x_k)  # type: ignore
        return value


class ConditionalMMPSDenoiser(MMPSDenoiser):
    """
    Conditional denoiser to predict E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}, y^{k+1}] for the stochastic Lorenz63 system with MMPS.
    Argument(s):
        - denoiser (Lorenz63Denoiser): a denoiser trained to estimate E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}].
        - y (Tensor): observation of the next state y^{k+1} with dimension (batch_size, d).
        - H (Callable[[Tensor], Tensor]): observation operator from (batch_size, 3) to (batch_size, d).
        - sigma_y (Tensor): covariance of the observations with dimension (d,) (passed to MMPSDenoiser as cov_y).
        - mean_x (Tensor): mean of states with dimension (3,).
        - std_x (Tensor): standard deviation of states with dimension (3,).
        - std_z (Tensor): standard deviation of residuals with dimension (3,).
        - num_iterations (int): number of iterations to solve the system in MMPS.
    """

    # Buffers registered in __init__; annotated here for type checkers.
    mean_x: Tensor
    std_x: Tensor
    std_z: Tensor
    lower: Tensor
    upper: Tensor

    def __init__(self,
        denoiser: Lorenz63Denoiser,
        y: Tensor,
        H: Callable[[Tensor], Tensor],
        sigma_y: Tensor,
        mean_x: Tensor,
        std_x: Tensor,
        std_z: Tensor,
        num_iterations: int = 1,
    ):
        # Use the init of MMPS
        super().__init__(
            denoiser=denoiser,
            y=y,
            A=H,
            cov_y=sigma_y,
            solver='gmres',
            iterations=num_iterations
        )

        # Save stats
        self.register_buffer("mean_x", mean_x)
        self.register_buffer("std_x", std_x)
        self.register_buffer("std_z", std_z)

        # Save lower and upper bounds (applied to the normalized residual estimate in forward)
        self.register_buffer("lower", torch.tensor([-3.75, -4., -3.15]))
        self.register_buffer("upper", torch.tensor([3.75, 4., 4.05]))

    def unnormalize_states(self, x: Tensor) -> Tensor:
        """
        Function to unnormalize a batch of states.
        Input(s):
            - x (Tensor): normalized states with dimension (batch_size, 3).
        Output:
            - std_x * x + mean_x, with dimension (batch_size, 3).
        """
        return self.std_x[None, :] * x + self.mean_x[None, :]

    def unnormalize_residuals(self, z: Tensor) -> Tensor:
        """
        Function to unnormalize a batch of residuals.
        Input(s):
            - z (Tensor): normalized residuals with dimension (batch_size, 3).
        Output:
            - std_z * z, with dimension (batch_size, 3) (residuals are only rescaled, not shifted).
        """
        return self.std_z[None, :] * z

    @torch.no_grad()
    def forward(self, x_t: Tensor, t: Tensor, x_k: Tensor) -> DiracPosterior: # type: ignore
        """
        Estimate E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}, y^{k+1}].
        The unconditional estimate of the wrapped denoiser is corrected with a
        term driven by the mismatch between y and the observation predicted
        from the unnormalized next state, then clamped to [lower, upper].
        Input(s):
            - x_t (Tensor): noised next normalized residuals at times t of the diffusion process with dimension (batch_size, 3).
            - t (Tensor): time of the diffusion process, either a scalar or with dimension (batch_size,).
            - x_k (Tensor): clean normalized previous states of the system with dimension (batch_size, 3).
        Output:
            - DiracPosterior whose mean is the clamped conditional estimate (computed without gradient tracking).
        """
        alpha_t, sigma_t = self.schedule(t)
        gamma_t = sigma_t**2 / alpha_t

        # gamma_t has the shape of t. Append singleton axes so that a batched
        # t of shape (batch_size,) scales each sample of a (batch_size, 3)
        # tensor instead of being broadcast along the component axis.
        # A scalar t gives the same result as before.
        while gamma_t.ndim < x_t.ndim:
            gamma_t = gamma_t[..., None]

        # Gradients are needed only to differentiate through the denoiser and
        # the observation operator; everything else runs under no_grad.
        with torch.enable_grad():
            x_t = x_t.detach().requires_grad_()
            q = self.denoiser(z_kp1_t=x_t, t=t, x_k=x_k)
            z_hat = q.mean
            # Predicted next state in physical units: previous state + residual
            x = self.unnormalize_residuals(z=z_hat) + self.unnormalize_states(x=x_k)
            y_hat = self.A(x)

        def A(v):
            # Jacobian-vector product of the observation operator at x
            return torch.func.jvp(self.A, (x.detach(),), (v,))[1]

        def At(v):
            # Vector-Jacobian product of the observation operator at x
            return torch.autograd.grad(y_hat, x, v, retain_graph=True)[0]

        def cov_x(v):
            # gamma_t * std_z * (vector-Jacobian product of x w.r.t. x_t);
            # presumably the covariance estimate of the unnormalized state used by MMPS.
            return gamma_t * self.std_z[None, :] * torch.autograd.grad(x, x_t, v, retain_graph=True)[0]

        def cov_y(v):
            # Observation covariance plus the state covariance mapped to observation space
            return self.cov_y(v) + A(cov_x(At(v)))

        # Solve cov_y(u) = y - y_hat, then pull u back to x_t through y_hat
        grad = self.y - y_hat
        grad = self.solve(A=cov_y, b=grad)
        grad = torch.autograd.grad(y_hat, x_t, grad)[0]

        posterior_clipped = torch.clamp(z_hat + gamma_t * grad, min=self.lower, max=self.upper)
        return DiracPosterior(mean=posterior_clipped)
    