
import torch
from utils import torch_dct  
from types import SimpleNamespace
from utils.misc import make_v_grid



def noising(x, args_sde, t=None):
    """
    Draw X_t from the OU bridge X_t | X_T = x, X_0 ~ N(0, Q/2A), computed in DCT space.

    x - (B,H,W) sample images (image space).
    t - optional times; beta_int(t) must have B elements (e.g. shape (B,) or
        (B,1)) or a single element. If None, t ~ U(0, T - 1e-5) with shape (B,1).

    Expects args_sde with:
    T - terminal time.
    beta_int - a function computing int_t^T beta(r) dr.
    A - a linear drift operator in eigenvalue representation (H,W).
    Q - a linear covariance operator in eigenvalue representation (H,W).

    Returns (t, x_t) with x_t of shape (B,1,H,W) in image space.
    """

    device = x.device
    dtype = x.dtype
    B, H, W = x.shape

    T = args_sde.T
    Q = args_sde.Q
    A = args_sde.A
    beta_int = args_sde.beta_int

    assert (H, W) == Q.shape
    assert (H, W) == A.shape

    # To the DCT domain, where A and Q act elementwise (eigenvalue representation).
    x_dct = torch_dct.dct_2d(x, norm="ortho")  # (B,H,W), real

    # `is None` rather than `== None`: for array-like t, `==` compares
    # elementwise and would make this `if` ambiguous.
    if t is None:
        # Sample uniform times on (0, T - 1e-5)
        t = torch.rand((B, 1), dtype=dtype, device=device) * (T - 1e-5)  # (B,1)

    # Statistics of the OU bridge X_t | X_T = y, X_0 ~ N(0, Q/2A)
    I_t = beta_int(t).view(-1, 1, 1)                          # (B,1,1) for broadcast
    S_t = torch.exp(-A[None, :, :] * I_t)                     # (B,H,W), S_t = U(T,t)
    mean_t = S_t * x_dct                                      # (B,H,W)
    var_t = (Q / (2.0 * A))[None, :, :] * (1.0 - S_t**2)      # (B,H,W)

    # Draw a Gaussian sample of X_t | X_T = y
    x_dct_t = mean_t + var_t**0.5 * torch.randn_like(x_dct)  # (B,H,W)

    # Back to image space, with a singleton channel axis
    x_t = torch_dct.idct_2d(x_dct_t, norm="ortho").unsqueeze(1)  # (B,1,H,W)

    return (t, x_t)



def compute_loss(score_nn, x, args_sde):
    """
    Score-matching loss for the forced diffusion, computed in DCT space.

    x - (B,1,H,W) batch of images.
    score_nn - called as score_nn(x=x_t, temp=t, v=v) with x_t (B,1,H,W),
        t (B,) and v from make_v_grid. Its (B,1,H,W) output, mapped to DCT
        space, is regressed onto the var_t-scaled transition score
        var_t * D_{x_t} log p(t, x_t; T, y).
    args_sde - provides T, Q (H,W), A (H,W) and beta_int(t) = int_t^T beta(r) dr.

    Returns the mean squared error (scalar tensor) over batch and all DCT
    coefficients.
    """

    device, dtype = x.device, x.dtype
    B, C, H, W = x.shape
    assert C == 1

    T = args_sde.T
    Q = args_sde.Q
    A = args_sde.A
    beta_int = args_sde.beta_int

    assert (H, W) == Q.shape
    assert (H, W) == A.shape

    # Transform images to DCT space
    x_dct = torch_dct.dct_2d(x[:, 0], norm="ortho")  # (B,H,W)

    # Stratified times in (0, T - 1e-2): one uniform draw per stratum
    # [k/B, (k+1)/B). This is a variance reduction over i.i.d. uniform times.
    u = (torch.arange(B, device=device, dtype=dtype)
         + torch.rand(B, device=device, dtype=dtype)) / B
    t = u[:, None] * (T - 1e-2)  # (B,1)
    # Shuffle so that time is not tied to batch position. The permutation is
    # built on `device` to avoid a host->device index copy.
    t = t[torch.randperm(B, device=device)]

    # 1. Statistics of the OU bridge X_t | X_T = y, X_0 ~ N(0, Q/2A)
    I_t = beta_int(t).view(-1, 1, 1)                          # (B,1,1) for broadcast
    S_t = torch.exp(-A[None, :, :] * I_t)                     # (B,H,W)
    mean_t = S_t * x_dct                                      # (B,H,W)
    var_t = (Q / (2.0 * A))[None, :, :] * (1.0 - S_t**2)      # (B,H,W)
    # Floor: var_t -> 0 as beta_int(t) -> 0, and sqrt(var_t) is taken below.
    var_t = torch.clamp(var_t, min=1e-12)

    # Sample x_t ~ X_t | X_T = y in DCT space
    x_dct_t = mean_t + torch.sqrt(var_t) * torch.randn_like(x_dct)  # (B,H,W)

    # Target: var_t-scaled transition score var_t * D_{x_t} log p(t, x_t; T, y)
    ou_dct_score = S_t * (x_dct - S_t * x_dct_t)  # (B,H,W)

    # Network prediction in image space, mapped to DCT space
    x_t = torch_dct.idct_2d(x_dct_t, norm="ortho").unsqueeze(1)        # (B,1,H,W)
    v = make_v_grid(B, H, W, device=device, dtype=dtype)               # (B,2,H,W)
    pred_score = score_nn(x=x_t, temp=t.view(B), v=v)                  # (B,1,H,W)
    pred_dct_score = torch_dct.dct_2d(pred_score[:, 0], norm="ortho")  # (B,H,W)

    # Mean squared error in DCT space
    err_dct = ou_dct_score - pred_dct_score  # (B,H,W)
    return (err_dct**2).mean()





def sample_forced_diffusion(score_nn, B, t_grid, args_sde):
    """
    Sample B states of the forced diffusion with a semi-implicit Euler-Maruyama scheme.

    Everything is done in the DCT eigenbasis, where A and Q act elementwise.
    Each step on [t_n, t_{n+1}] with dt = t_{n+1} - t_n computes

        Y_{n+1} = (Y_n + dt * beta(t_n) Q s_n / var(t_n) + sqrt(beta(t_n) Q dt) dW)
                  / (1 + dt * beta(t_{n+1}) A),

    where:
    - s_n is the network output mapped to DCT space. This is the var_t-scaled
      score, matching the training target in compute_loss.
    - var(t) = Q/(2A) * (1 - exp(-2 A beta_int(t))), clamped below at 1e-12.

    score_nn - called as score_nn(x=x_img, temp=t, v=v) with x_img (B,1,H,W), t (B,).
    B - number of samples.
    t_grid - 1-D tensor of strictly increasing times (sqrt(dt) is taken, and
        elements are `.expand`-ed to (B,)).
    args_sde - provides H, W, Q (H,W), A (H,W), beta(t) and beta_int(t).

    Returns Y of shape (B,H,W) in the DCT domain. The initial state is drawn
    from N(0, diag(Q/2A)).
    """

    # Get SDE params
    H = args_sde.H
    W = args_sde.W
    Q = args_sde.Q
    A = args_sde.A
    beta_int = args_sde.beta_int
    beta = args_sde.beta

    device = Q.device
    dtype = Q.dtype

    assert (H, W) == Q.shape
    assert (H, W) == A.shape

    eps = 1e-12  # floor for var_t, which vanishes as beta_int(t) -> 0
    N = len(t_grid)

    v = make_v_grid(B, H, W, device=device, dtype=dtype)  # (B,2,H,W)

    # Loop-invariant operator pieces, broadcast over the batch axis
    C = Q / (2.0 * A)                     # (H,W) invariant variance Q/2A
    A_b = A[None, :, :]                   # (1,H,W)
    Q_b = Q[None, :, :]                   # (1,H,W)
    C_b = C[None, :, :]                   # (1,H,W)
    sqrt_Q_b = torch.sqrt(Q)[None, :, :]  # (1,H,W)

    # Initial sample in DCT domain: invariant measure N(0, diag(Q/2A))
    Y = torch.sqrt(C)[None, :, :] * torch.randn(B, H, W, device=device, dtype=dtype)  # (B,H,W)

    with torch.no_grad():
        for i in range(N - 1):
            t_i, t_ip1 = t_grid[i], t_grid[i + 1]
            dt = t_ip1 - t_i                 # scalar tensor

            t_n = t_i.expand(B)              # (B,)
            t_np1 = t_ip1.expand(B)          # (B,)

            beta_n = beta(t_n).view(B, 1, 1)      # (B,1,1)
            beta_np1 = beta(t_np1).view(B, 1, 1)  # (B,1,1)

            # Evaluate score in image space, then map to DCT space
            x_img = torch_dct.idct_2d(Y, norm="ortho").unsqueeze(1)      # (B,1,H,W)
            score_img = score_nn(x=x_img, temp=t_n, v=v)                 # (B,1,H,W)
            score_dct = torch_dct.dct_2d(score_img[:, 0], norm="ortho")  # (B,H,W)

            # Undo the var_t scaling of the network output
            I_t = beta_int(t_n).view(B, 1, 1)                            # (B,1,1)
            S_t = torch.exp(-A_b * I_t)                                  # (B,H,W)
            var_t = torch.clamp(C_b * (1.0 - S_t**2), min=eps)           # (B,H,W)

            # Explicit part: beta(t) Q * score / var_t, plus noise
            score_drift = beta_n * Q_b * (score_dct / var_t)
            noise = torch.sqrt(beta_n) * torch.sqrt(dt) * sqrt_Q_b * torch.randn_like(Y)
            rhs = Y + dt * score_drift + noise

            # Implicit part: backward-Euler solve of the diagonal linear drift
            Y = rhs / (1.0 + dt * (beta_np1 * A_b))

    return Y






def sample_forced_diffusion_trotter(score_nn, B, t_grid, args_sde):
    """
    Lie-Trotter splitting sampler; returns the final state Y (B,H,W) in the DCT domain.

    Each interval [t_a, t_b] with h = t_b - t_a is advanced in two parts:
      (B) An explicit Euler-Maruyama step frozen at t_a. It applies the learned
          score drift beta(t_a) Q * s / var(t_a), with var clamped at 1e-12,
          plus the noise sqrt(beta(t_a) Q h) dW.
      (A) A backward-Euler step for the linear drift -beta(t_b) A Y.
    This composition is the same per-step update as sample_forced_diffusion.
    The initial state is drawn from N(0, diag(Q/2A)).
    """

    H, W, T, Q, A = args_sde.H, args_sde.W, args_sde.T, args_sde.Q, args_sde.A
    beta_int, beta = args_sde.beta_int, args_sde.beta
    device, dtype = Q.device, Q.dtype

    assert (H, W) == Q.shape
    assert (H, W) == A.shape

    n_steps = len(t_grid) - 1

    v = make_v_grid(B, H, W, device=device, dtype=dtype)  # (B,2,H,W)

    # Start from the invariant measure N(0, diag(Q/2A)) in DCT space
    stat_var = Q / (2.0 * A)  # (H,W)
    Y = torch.sqrt(stat_var)[None, :, :] * torch.randn(B, H, W, device=device, dtype=dtype)

    floor = 1e-12
    A_row = A[None, :, :]
    Q_row = Q[None, :, :]
    var_inf = stat_var[None, :, :]
    sqrt_Q_row = torch.sqrt(Q)[None, :, :]

    with torch.no_grad():
        for k in range(n_steps):
            t_a, t_b = t_grid[k], t_grid[k + 1]
            h = t_b - t_a
            ta_vec, tb_vec = t_a.expand(B), t_b.expand(B)

            b_a = beta(ta_vec).view(B, 1, 1)
            b_b = beta(tb_vec).view(B, 1, 1)

            # (B) explicit: network score, rescaled by 1/var, plus noise at t_a
            img = torch_dct.idct_2d(Y, norm="ortho").unsqueeze(1)
            net_out = score_nn(x=img, temp=ta_vec, v=v)
            s_dct = torch_dct.dct_2d(net_out[:, 0], norm="ortho")

            decay = torch.exp(-A_row * beta_int(ta_vec).view(B, 1, 1))
            variance = torch.clamp(var_inf * (1.0 - decay**2), min=floor)

            drift = b_a * Q_row * (s_dct / variance)
            kick = torch.sqrt(b_a) * torch.sqrt(h) * sqrt_Q_row * torch.randn_like(Y)
            Y_half = Y + h * drift + kick

            # (A) implicit: backward Euler for -beta(t_b) A Y
            Y = Y_half / (1.0 + h * (b_b * A_row))

    return Y




def sample_forced_diffusion_strang(score_nn, B, t_grid, args_sde):
    """
    Strang-splitting sampler; returns the final state Y (B,H,W) in the DCT domain.

    Each interval [t_n, t_{n+1}] with dt = t_{n+1} - t_n is advanced by:
      1. A backward-Euler half step (dt/2) of the linear drift -beta(t_n) A Y.
      2. A full explicit Euler-Maruyama step at t_mid = t_n + dt/2. It applies
         the learned score drift beta(t_mid) Q * s / var(t_mid), with var
         clamped at 1e-12, plus the noise sqrt(beta(t_mid) Q dt) dW.
      3. A backward-Euler half step (dt/2) of -beta(t_{n+1}) A Y.
    The initial state is drawn from N(0, diag(Q/2A)).

    NOTE(review): the splitting is symmetric, but each linear half step is a
    backward-Euler (first-order) approximation of exp(-dt/2 * beta * A). The
    linear part is therefore not second-order accurate; confirm whether this
    is intended.
    """

    H, W, T, Q, A = args_sde.H, args_sde.W, args_sde.T, args_sde.Q, args_sde.A
    beta_int, beta = args_sde.beta_int, args_sde.beta
    device, dtype = Q.device, Q.dtype

    assert (H, W) == Q.shape
    assert (H, W) == A.shape

    n_steps = len(t_grid) - 1

    v = make_v_grid(B, H, W, device=device, dtype=dtype)  # (B,2,H,W)

    # Start from the invariant measure N(0, diag(Q/2A)) in DCT space
    stat_var = Q / (2.0 * A)  # (H,W)
    Y = torch.sqrt(stat_var)[None, :, :] * torch.randn(B, H, W, device=device, dtype=dtype)

    floor = 1e-12
    A_row = A[None, :, :]
    Q_row = Q[None, :, :]
    var_inf = stat_var[None, :, :]
    sqrt_Q_row = torch.sqrt(Q)[None, :, :]

    def linear_half_step(state, beta_val, half_dt):
        # Backward Euler for dY/dt = -beta A Y over half_dt; beta_val is (B,1,1)
        return state / (1.0 + half_dt * (beta_val * A_row))

    def score_noise_step(state, t_vec, step):
        # Explicit score drift + noise, all evaluated at the (B,) times t_vec
        b = beta(t_vec).view(B, 1, 1)
        img = torch_dct.idct_2d(state, norm="ortho").unsqueeze(1)
        net_out = score_nn(x=img, temp=t_vec, v=v)
        s_dct = torch_dct.dct_2d(net_out[:, 0], norm="ortho")

        decay = torch.exp(-A_row * beta_int(t_vec).view(B, 1, 1))
        variance = torch.clamp(var_inf * (1.0 - decay**2), min=floor)

        drift = b * Q_row * (s_dct / variance)
        kick = torch.sqrt(b) * torch.sqrt(step) * sqrt_Q_row * torch.randn_like(state)
        return state + step * drift + kick

    with torch.no_grad():
        for k in range(n_steps):
            t_lo, t_hi = t_grid[k], t_grid[k + 1]
            step = t_hi - t_lo
            half = 0.5 * step
            t_mid = (t_lo + 0.5 * step).expand(B)

            b_lo = beta(t_lo.expand(B)).view(B, 1, 1)
            b_hi = beta(t_hi.expand(B)).view(B, 1, 1)

            Y = linear_half_step(Y, b_lo, half)
            Y = score_noise_step(Y, t_mid, step)
            Y = linear_half_step(Y, b_hi, half)

    return Y
