import torch

def compute_loss(score_nn, x, beta, C, T, device, dtype, eps=1e-2):
    """Denoising score-matching loss for the C-preconditioned score.

    Each sample is noised with the Gaussian transition kernel
    ``X_t | X_0 = x ~ N(S_t * x, C * (1 - S_t**2))`` where
    ``S_t = exp(-0.5 * beta * t)``, and ``score_nn`` is regressed onto
    ``C * grad_x log p(x_t | x)``.

    Args:
        score_nn: Callable invoked as ``score_nn(t, x_t)`` with ``t`` of shape
            (B, 1) and ``x_t`` of shape (B, D); its output is compared
            elementwise against a (B, D) target.
        x: Training samples of shape (B, D).
        beta: Noise-rate factor used in ``exp(-0.5 * beta * t)``.
        C: Diagonal covariance entries; must broadcast against (B, D)
            (the original comments use shape (1, D)).
        T: Final time; times are drawn from ``[eps, T - eps)``.
        device: Device on which the time samples are created.
        dtype: Floating dtype of the time samples.
        eps: Offset that keeps the sampled times away from 0 and ``T``.
            Defaults to the previously hard-coded value ``1e-2``.

    Returns:
        Scalar tensor: batch mean of the per-sample squared error summed
        over the D coordinates.
    """
    B, _ = x.shape  # also enforces that x is 2-D

    # Stratified time sample: one uniform draw inside each of the B equal
    # sub-intervals of [0, 1), rescaled to [eps, T - eps) and then shuffled so
    # that the time is not correlated with the position in the batch.
    s = (torch.arange(B, device=device, dtype=dtype)
         + torch.rand(B, device=device, dtype=dtype)) / B          # (B,)
    t = s[:, None] * (T - 2 * eps) + eps                            # (B,1)
    t = t[torch.randperm(B, device=device)]

    # Conditional mean and (diagonal) variance of X_t | X_0 = x.
    S_t = torch.exp(-0.5 * beta * t)                                # (B,1)
    mean_t = S_t * x                                                # (B,D)
    # Clamp so that the division below stays finite for very small t.
    var_t = (C * (1.0 - S_t ** 2)).clamp_min(1e-8)                  # (B,D)
    std_t = torch.sqrt(var_t)

    noise = torch.randn_like(x)
    x_t = mean_t + std_t * noise                                    # (B,D)

    # Target C * grad log p(x_t | x) = -C * (x_t - mean_t) / var_t.
    # Since x_t - mean_t == std_t * noise, use the noise directly instead of
    # re-forming the difference, which avoids the subtraction round-off.
    score = -(noise / std_t) * C                                    # (B,D)

    pred_score = score_nn(t, x_t)

    return torch.mean(torch.sum((pred_score - score) ** 2, dim=1))



def compute_distance(score_nn, x, u, beta, C, alpha, T):
    """Average distance between ``score_nn`` and the exact preconditioned score.

    The target measure is treated as the two-component Gaussian mixture
    ``alpha * N(u, diag(C)) + (1 - alpha) * N(-u, diag(C))``. Samples ``x`` are
    noised to a random time ``t``, the exact mixture score at ``(t, x_t)`` is
    evaluated in closed form, and the network output is compared against
    ``C * true_score``.

    Args:
        score_nn: Callable invoked as ``score_nn(t, x_t)`` with ``t`` of shape
            (B, 1) and ``x_t`` of shape (B, D).
        x: Samples of the target measure, shape (B, D).
        u: Mean of the positive component (the negative one uses ``-u``);
            must broadcast against (B, D).
        beta: Noise-rate factor used in ``exp(-0.5 * beta * t)``.
        C: Diagonal covariance entries; must broadcast against (B, D).
        alpha: Weight of the positive component.
        T: Upper end of the time interval; times are uniform on ``[1e-5, T)``.

    Returns:
        Scalar tensor: batch mean of the Euclidean norm of
        ``score_nn(t, x_t) - C * true_score``.
    """
    n_batch, _ = x.shape
    t_min = 1e-5

    # One uniform time per sample in [t_min, T).
    times = torch.rand((n_batch, 1), dtype=x.dtype, device=x.device) * (T - t_min) + t_min

    mix = torch.as_tensor(alpha, dtype=x.dtype, device=x.device)  # scalar tensor

    # Forward noising X_t | X_0 = x: mean decay * x, variance (1 - decay^2) * C.
    decay = torch.exp(-0.5 * beta * times)                        # (B,1)
    decay_sq = decay ** 2
    x_noisy = decay * x + torch.sqrt((1.0 - decay_sq) * C) * torch.randn_like(x)

    # Time-t marginals of the two components: means +/- decay * u and a shared
    # variance (transition noise plus the decayed component covariance).
    center = decay * u
    comp_var = (1.0 - decay_sq) * C + decay_sq * C

    # Posterior component weights from log-densities; the -0.5*D*log(2*pi)
    # constant is dropped because it cancels in the normalisation.
    log_det = torch.log(comp_var).sum(dim=1, keepdim=True)        # (B,1)
    resid_pos = x_noisy - center
    resid_neg = x_noisy + center
    quad_pos = (resid_pos ** 2 / comp_var).sum(dim=1, keepdim=True)
    quad_neg = (resid_neg ** 2 / comp_var).sum(dim=1, keepdim=True)

    logit_pos = torch.log(mix) + (-0.5 * (quad_pos + log_det))
    logit_neg = torch.log(1.0 - mix) + (-0.5 * (quad_neg + log_det))

    # Subtract the larger logit before exponentiating so nothing overflows.
    shift = torch.maximum(logit_pos, logit_neg)
    e_pos = torch.exp(logit_pos - shift)
    e_neg = torch.exp(logit_neg - shift)
    w_pos = e_pos / (e_pos + e_neg)
    w_neg = 1 - w_pos

    # Mixture score: weighted sum of the per-component Gaussian scores.
    exact_score = -(w_pos * resid_pos / comp_var + w_neg * resid_neg / comp_var)

    # The network is compared against the C-preconditioned score.
    gap = score_nn(times, x_noisy) - C * exact_score              # (B,D)
    return torch.linalg.norm(gap, dim=1, keepdim=True).mean()



def generate_samples(score_nn, B, D, N, beta, C, T, device, dtype, eps=1e-2):
    """Draw samples by integrating the reverse-time dynamics from ``T`` to ``eps``.

    Starts from ``N(0, diag(C))`` and takes ``N - 1`` steps on a uniform time
    grid from ``T`` down to ``eps``. The score term is treated explicitly and
    the linear term ``-0.5 * beta * x`` implicitly (semi-implicit
    Euler-Maruyama).

    Args:
        score_nn: Callable invoked as ``score_nn(t, x)`` with ``t`` of shape
            (B, 1) and ``x`` of shape (B, D); returns a (B, D) tensor.
        B: Number of samples to generate.
        D: Dimension of each sample.
        N: Number of time-grid points (``N - 1`` integration steps).
        beta: Noise-rate factor of the dynamics.
        C: Diagonal covariance entries; must broadcast against (B, D).
        T: Start time of the reverse integration.
        device: Device for all created tensors.
        dtype: Floating dtype for all created tensors.
        eps: Terminal time of the integration. This value is honoured as
            passed; it is no longer replaced by ``torch.finfo(dtype).eps``.

    Returns:
        Tensor of shape (B, D) holding the state at the last grid point.

    Note:
        Each step divides by ``1 + 0.5 * beta * dt`` with ``dt < 0``, so the
        step size must satisfy ``|dt| < 2 / beta`` for the divisor to stay
        positive.
    """
    t_grid = torch.linspace(T, eps, N, device=device, dtype=dtype)  # (N,)
    sqrt_C = torch.sqrt(C)

    # Loop-invariant part of the noise amplitude, sqrt(beta) * sqrt(C), built
    # on the requested device/dtype so it matches the state tensor.
    noise_scale = torch.sqrt(torch.as_tensor(beta, device=device, dtype=dtype)) * sqrt_C

    # Initial distribution N(0, diag(C)).
    x = sqrt_C * torch.randn((B, D), device=device, dtype=dtype)

    with torch.no_grad():
        for i in range(N - 1):
            t = t_grid[i]
            dt = t_grid[i + 1] - t                                # negative

            t_batch = t.expand(B, 1)

            # Explicit nonlinear part (score term) evaluated at (t_n, x_n).
            score = score_nn(t_batch, x).detach()                 # (B,D)
            drift_explicit = -beta * score                        # (B,D)

            # Noise term: sqrt(beta) * sqrt(C) * sqrt(|dt|) * N(0, I).
            noise = noise_scale * torch.sqrt(torch.abs(dt)) * torch.randn_like(x)

            # Implicit step for the linear part -0.5*beta*x:
            # x_{n+1} = x_n + dt*(-0.5*beta*x_{n+1} + drift_explicit) + noise
            denom = 1.0 + 0.5 * beta * dt                         # scalar (dt<0)
            x = (x + dt * drift_explicit + noise) / denom

    return x