import torch


# TODO: the high-dimensional p_t and grad_log_p_t read only alpha_t[0][0] and
# beta_t[0][0], so they do not support a batched (per-sample) t.
def p_t(x_t, alpha_t, beta_t, mu, sigma_sq, w, log=False):
    """Evaluate the density of a Gaussian mixture at time t.

    Component ``k`` has mean ``alpha * mu[k]`` and covariance
    ``alpha**2 * sigma_sq[k] + beta**2 * I`` where ``alpha = alpha_t[0][0]``
    and ``beta = beta_t[0][0]`` (a single time value is used for the whole
    batch).

    The mixture is evaluated in log space (``logdet`` + ``logsumexp``) so that
    samples far from every mode yield a finite log-density instead of
    ``log(0) = -inf``, and so the normalising constant cannot overflow for
    large ``d``.

    Args:
        x_t: samples, shape ``[N, D]``.
        alpha_t: tensor indexable as ``alpha_t[0][0]``; only that entry is used.
        beta_t: tensor indexable as ``beta_t[0][0]``; only that entry is used.
        mu: component means, shape ``[K, D]``.
        sigma_sq: component covariance matrices, shape ``[K, D, D]``.
        w: mixture weights, shape ``[K]``; must be non-negative because their
            logarithm is taken.
        log: if True return ``log p_t(x_t)``, otherwise ``p_t(x_t)``.

    Returns:
        Tensor of shape ``[N, 1]``.
    """
    mu = mu.to(x_t.device)
    sigma_sq = sigma_sq.to(x_t.device)
    w = w.to(x_t.device)
    n_samples = x_t.shape[0]  # Number of samples in the batch
    d = x_t.shape[1]  # Dimensionality of each sample
    k = w.shape[0]  # Number of mixture components

    # Only a single (alpha, beta) pair is supported for the whole batch.
    alpha_t = alpha_t[0][0]
    beta_t = beta_t[0][0]
    mu_t = alpha_t * mu  # [K, D]
    sigma_t = alpha_t ** 2 * sigma_sq + beta_t ** 2 * torch.eye(d, device=x_t.device)  # [K, D, D]

    sigma_t_inv = torch.inverse(sigma_t)  # [K, D, D]
    # logdet is nan for a negative determinant, matching sqrt(det) in linear space.
    log_det_sigma_t = torch.logdet(sigma_t)  # [K]

    # Pairwise differences between every sample and every component mean.
    mu_t_expanded = mu_t.unsqueeze(1).expand(k, n_samples, d)
    x_t_expanded = x_t.unsqueeze(0).expand(k, n_samples, d)
    diff = x_t_expanded - mu_t_expanded  # [K, N, D]

    # Squared Mahalanobis distance diff^T Sigma^{-1} diff per (component, sample).
    maha = torch.einsum('kni,kij,knj->kn', diff, sigma_t_inv, diff)  # [K, N]

    # log of the Gaussian normaliser: -0.5 * (d * log(2*pi) + log|Sigma|).
    log_two_pi = torch.full_like(log_det_sigma_t, 2 * torch.pi).log()
    log_norm_const = -0.5 * (d * log_two_pi + log_det_sigma_t)  # [K]

    log_weighted = torch.log(w).unsqueeze(1) + log_norm_const.unsqueeze(1) - 0.5 * maha  # [K, N]
    log_p_x_t = torch.logsumexp(log_weighted, dim=0).unsqueeze(-1)  # [N, 1]

    return log_p_x_t if log else torch.exp(log_p_x_t)

def grad_log_p_t(x_t, alpha_t, beta_t, mu, sigma_sq, w):
    """Score ``grad_x log p_t(x)`` of a Gaussian mixture at time t.

    Component ``k`` has mean ``alpha * mu[k]`` and covariance
    ``alpha**2 * sigma_sq[k] + beta**2 * I`` where ``alpha = alpha_t[0][0]``
    and ``beta = beta_t[0][0]`` (a single time value is used for the whole
    batch).

    The score is computed as ``-sum_k r_k(x) * Sigma_k^{-1} (x - mu_k)`` where
    the responsibilities ``r_k`` come from a softmax over per-component
    log-densities. Dividing the linear-space weighted sum by ``p + eps``
    instead collapses to 0 once every component density underflows, which
    gives a wrong (zero) score far away from the modes.

    Args:
        x_t: samples, shape ``[N, D]``.
        alpha_t: tensor indexable as ``alpha_t[0][0]``; only that entry is used.
        beta_t: tensor indexable as ``beta_t[0][0]``; only that entry is used.
        mu: component means, shape ``[K, D]``.
        sigma_sq: component covariance matrices, shape ``[K, D, D]``.
        w: mixture weights, shape ``[K]``; must be non-negative because their
            logarithm is taken.

    Returns:
        Tensor of shape ``[N, D]`` holding the score at each sample.
    """
    # Keep parameters on the same device as the samples (as p_t does).
    mu = mu.to(x_t.device)
    sigma_sq = sigma_sq.to(x_t.device)
    w = w.to(x_t.device)
    n_samples = x_t.shape[0]  # Number of samples in the batch
    d = x_t.shape[1]  # Dimensionality of each sample
    k = w.shape[0]  # Number of mixture components

    # Only a single (alpha, beta) pair is supported for the whole batch.
    alpha_t = alpha_t[0][0]
    beta_t = beta_t[0][0]
    mu_t = alpha_t * mu  # [K, D]
    sigma_t = alpha_t ** 2 * sigma_sq + beta_t ** 2 * torch.eye(d, device=x_t.device)  # [K, D, D]

    sigma_t_inv = torch.inverse(sigma_t)  # [K, D, D]
    log_det_sigma_t = torch.logdet(sigma_t)  # [K]

    # Pairwise differences between every sample and every component mean.
    mu_t_expanded = mu_t.unsqueeze(1).expand(k, n_samples, d)
    x_t_expanded = x_t.unsqueeze(0).expand(k, n_samples, d)
    diff = x_t_expanded - mu_t_expanded  # [K, N, D]

    # Sigma_k^{-1} (x - mu_k); reused for the Mahalanobis term and the score.
    sigma_inv_diff = torch.einsum('kij,knj->kni', sigma_t_inv, diff)  # [K, N, D]
    maha = torch.einsum('kni,kni->kn', diff, sigma_inv_diff)  # [K, N]

    # Unnormalised log w_k N_k(x); the (2*pi)^{d/2} factor is shared by all
    # components and cancels in the softmax, so it is omitted.
    log_weighted = torch.log(w).unsqueeze(1) - 0.5 * log_det_sigma_t.unsqueeze(1) - 0.5 * maha
    responsibilities = torch.softmax(log_weighted, dim=0)  # [K, N]

    grad_log_p = -torch.einsum('kn,kni->ni', responsibilities, sigma_inv_diff)  # [N, D]

    return grad_log_p

def grad_log_p_t_test(x_t, alpha_t, beta_t, mu, sigma_sq, w):
    """Experimental variant of the Gaussian-mixture score computation.

    Uses ``alpha = alpha_t[0][0]`` and ``beta = beta_t[0][0]``; component means
    are ``alpha * mu + beta`` and covariances are
    ``alpha**2 * sigma_sq + beta**2 * I``.

    NOTE(review): compared with ``grad_log_p_t`` this function (a) adds
    ``beta`` to the component means, (b) negates the result once more before
    returning, and (c) contracts the inverse covariance over its second index
    (``'nk,nkd,kde->nd'``), which only matches ``Sigma^{-1} diff`` for diagonal
    covariances. Confirm whether these differences are intentional.

    Returns:
        Tensor of shape ``[N, D]``.
    """
    dim = x_t.shape[-1]
    alpha = alpha_t[0][0]
    beta = beta_t[0][0]

    # Per-component parameters at time t.
    means = alpha * mu + beta  # [K, D]
    covs = alpha ** 2 * sigma_sq + beta ** 2 * torch.eye(dim, device=x_t.device)  # [K, D, D]
    cov_inv = torch.linalg.inv(covs)  # [K, D, D]
    cov_det = torch.linalg.det(covs)  # [K]

    # Weighted component densities for every (sample, component) pair.
    delta = x_t[:, None, :] - means[None, :, :]  # [N, K, D]
    quad_form = torch.einsum('nkd,kde,nke->nk', delta, cov_inv, delta)  # [N, K]
    normaliser = torch.sqrt((2 * torch.pi) ** dim * cov_det)  # [K]
    component_density = torch.exp(-0.5 * quad_form) / normaliser  # [N, K]
    component_density = component_density * w[None, :]  # [N, K]

    # Mixture density, kept as a column for broadcasting.
    mixture_density = torch.sum(component_density, dim=-1, keepdim=True)  # [N, 1]

    numerator = torch.einsum('nk,nkd,kde->nd', component_density, delta, cov_inv)  # [N, D]
    negated_ratio = -numerator / (mixture_density + 1e-12)  # [N, D]
    return -negated_ratio

# def grad_log_p_t(x_t, alpha_t, beta_t, mu, sigma_sq, w):
#     alpha_t = alpha_t.repeat(0, mu.shape[1])
#     mus_t = alpha_t * mu.squeeze(-1)
#     sigmas_t = alpha_t ** 2 * sigma_sq.squeeze(-1) + beta_t ** 2
#
#     # 
#     # ps = torch.exp(-0.5 * (x_t[:, None, None] - mus_t)**2 / sigmas_t) / torch.sqrt(2 * torch.pi * sigmas_t) * w[:,None]
#     ps = torch.exp(-0.5 * (x_t[:, None] - mus_t) ** 2 / sigmas_t) / torch.sqrt(2 * torch.pi * sigmas_t) * w[None, :]
#     # 
#     grad_log_p = -(torch.sum((x_t[:, None] - mus_t) * ps /sigmas_t, -1))  / (torch.sum(ps, -1) + 1e-12)
#     return grad_log_p.unsqueeze(-1)
#
# def p_t(x_t, alpha_t, beta_t, mu, sigma_sq, w):
#     x_t = x_t.squeeze(-1)
#     mus_t = alpha_t * mu.squeeze(-1)
#     sigmas_t = alpha_t ** 2 * sigma_sq.squeeze(-1) + beta_t ** 2
#
#     # 
#     # ps = torch.exp(-0.5 * (x_t[:, None, None] - mus_t)**2 / sigmas_t) / torch.sqrt(2 * torch.pi * sigmas_t) * w[:,None]
#     ps = torch.exp(-0.5 * (x_t[:, None] - mus_t) ** 2 / sigmas_t) / torch.sqrt(2 * torch.pi * sigmas_t) * w[None, :]
#     return torch.sum(ps, -1).unsqueeze(-1)

def grad_log_p_t_1d(x_t, alpha_t, beta_t, mu, sigma_sq, w):
    """Score ``d/dx log p_t(x)`` of a one-dimensional Gaussian mixture.

    Component ``k`` has mean ``alpha_t * mu[k]`` and variance
    ``alpha_t**2 * sigma_sq[k] + beta_t**2``. ``alpha_t`` / ``beta_t`` are used
    as given and broadcast against the ``[N, K]`` sample/component grid.

    The score is computed as ``-sum_k r_k(x) * (x - mu_k) / var_k`` with the
    responsibilities ``r_k`` obtained from a softmax over per-component
    log-densities. Dividing linear-space densities by ``sum + eps`` instead
    returns 0 once all component densities underflow, i.e. far from the modes.

    Args:
        x_t: samples; the trailing dimension is squeezed, then indexed as ``[N]``.
        alpha_t: signal scale, broadcastable to ``[N, K]``.
        beta_t: noise scale, broadcastable to ``[N, K]``.
        mu: component means; trailing dimension is squeezed.
        sigma_sq: component variances; trailing dimension is squeezed.
        w: mixture weights, shape ``[K]``; must be non-negative because their
            logarithm is taken.

    Returns:
        Tensor with a trailing singleton dimension appended (``[N, 1]`` for
        ``[N, 1]`` input).
    """
    x_t = x_t.squeeze(-1)
    mus_t = alpha_t * mu.squeeze(-1)
    sigmas_t = alpha_t ** 2 * sigma_sq.squeeze(-1) + beta_t ** 2

    diff = x_t[:, None] - mus_t  # [N, K]

    # Unnormalised log w_k N_k(x); the shared 1/sqrt(2*pi) factor cancels in
    # the softmax and is therefore omitted.
    log_ps = -0.5 * diff ** 2 / sigmas_t - 0.5 * torch.log(sigmas_t) + torch.log(w)[None, :]
    responsibilities = torch.softmax(log_ps, dim=-1)  # [N, K]

    grad_log_p = -torch.sum(responsibilities * diff / sigmas_t, -1)
    return grad_log_p.unsqueeze(-1)

def p_t_1d(x_t, alpha_t, beta_t, mu, sigma_sq, w):
    """Density of a one-dimensional Gaussian mixture at time t.

    Component ``k`` has mean ``alpha_t * mu[k]`` and variance
    ``alpha_t**2 * sigma_sq[k] + beta_t**2``; the weighted component densities
    are summed over the last axis.

    Returns:
        The mixture density with a trailing singleton dimension appended.
    """
    samples = x_t.squeeze(-1)
    means = alpha_t * mu.squeeze(-1)
    variances = alpha_t ** 2 * sigma_sq.squeeze(-1) + beta_t ** 2

    # Gaussian kernel and normaliser for every (sample, component) pair.
    squared_dev = (samples[:, None] - means) ** 2
    kernel = torch.exp(-0.5 * squared_dev / variances)
    normaliser = torch.sqrt(2 * torch.pi * variances)

    weighted_density = kernel / normaliser * w[None, :]
    return torch.sum(weighted_density, -1).unsqueeze(-1)

def v_t(x_t, alpha_t, beta_t, alpha_t_dot, beta_t_dot, mu, sigma_sq, w):
    """Combine the mixture score with the schedule derivatives.

    Returns ``r * x_t - beta_t * (beta_t_dot - r * beta_t) * score`` where
    ``r = alpha_t_dot / alpha_t`` and ``score`` is ``grad_log_p_t`` evaluated
    at ``x_t`` (presumably the velocity field of the flow — confirm with the
    caller).
    """
    score = grad_log_p_t(x_t, alpha_t, beta_t, mu, sigma_sq, w)
    log_alpha_rate = alpha_t_dot / alpha_t
    linear_term = log_alpha_rate * x_t
    score_scale = beta_t * (beta_t_dot - (log_alpha_rate * beta_t))
    return linear_term - score_scale * score

def v_t_1d(x_t, alpha_t, beta_t, alpha_t_dot, beta_t_dot, mu, sigma_sq, w):
    """One-dimensional counterpart of the score/schedule combination.

    Returns ``r * x_t - beta_t * (beta_t_dot - r * beta_t) * score`` where
    ``r = alpha_t_dot / alpha_t`` and ``score`` comes from ``grad_log_p_t_1d``.
    """
    score = grad_log_p_t_1d(x_t, alpha_t, beta_t, mu, sigma_sq, w)
    rate = alpha_t_dot / alpha_t
    return rate * x_t - beta_t * (beta_t_dot - (rate * beta_t)) * score

def schedule(t, type='vp'):
    """Return the pair ``(alpha_t, beta_t)`` for the requested schedule.

    * ``'linear'``: ``alpha_t = t`` and ``beta_t = 1 - t``.
    * ``'vp'`` / ``'subvp'``:
      ``alpha_t = exp(-0.25 * (1 - t)**2 * a - 0.5 * (1 - t) * b)`` with
      ``a = 20 - 0.1`` and ``b = 0.1``; ``beta_t`` is ``sqrt(1 - alpha_t**2)``
      for ``'vp'`` and ``1 - alpha_t**2`` for ``'subvp'``.

    Raises:
        NotImplementedError: for any other ``type``.
    """
    if type == 'linear':
        return t, 1-t
    if type not in ('vp', 'subvp'):
        raise NotImplementedError()

    a = 20 - 0.1
    b = 0.1
    alpha_t = torch.exp(-0.25 * (1 - t) ** 2 * a - 0.5 * (1 - t) * b)
    one_minus_alpha_sq = 1 - alpha_t**2
    if type == "vp":
        beta_t = one_minus_alpha_sq ** 0.5
    else:
        beta_t = one_minus_alpha_sq
    return alpha_t, beta_t

def abdot(t, type='vp'):
    """Return the time derivatives ``(alpha_t_dot, beta_t_dot)`` of a schedule.

    * ``'linear'``: the constants ``(1, -1)`` (plain Python ints).
    * ``'vp'`` / ``'subvp'``: with
      ``alpha_t = exp(-0.25 * (1 - t)**2 * a - 0.5 * (1 - t) * b)``,
      ``a = 20 - 0.1`` and ``b = 0.1``,
      ``alpha_t_dot = alpha_t * (0.5 * a * (1 - t) + 0.5 * b)``.
      ``beta_t_dot`` is the derivative of ``sqrt(1 - alpha_t**2)`` for ``'vp'``
      (with ``1e-8`` added under the root to avoid dividing by zero when
      ``alpha_t == 1``) and of ``1 - alpha_t**2`` for ``'subvp'``.

    Raises:
        NotImplementedError: for any other ``type``.
    """
    if type == 'linear':
        return 1, -1
    if type not in ('vp', 'subvp'):
        raise NotImplementedError()

    a = 20 - 0.1
    b = 0.1
    alpha_t = torch.exp(-0.25 * (1 - t) ** 2 * a - 0.5 * (1 - t) * b)
    alpha_t_dot = alpha_t * (0.5 * a * (1 - t) + 0.5 * b)
    # beta_t itself is not needed here; only its derivative is returned.
    if type == "vp":
        beta_t_dot = 0.5 * (((1 - alpha_t**2)+1e-8)) ** (-0.5) * (-2 * alpha_t) * alpha_t_dot
    else:
        beta_t_dot = (-2 * alpha_t) * alpha_t_dot
    return alpha_t_dot, beta_t_dot
