import torch
import numpy as np
from scipy.linalg import sqrtm

# Let's just implement the simplest thing for now: diagonal initial covariance and diagonal initial Hessian
# This requires the covariance object to store the current diagonal term and a list of update vectors used so far
# don't try to optimise with any fancy stuff for now, just get it working

class CovarianceHessianBFGS:
    """Store the denoiser covariance, the Hessian and both inverses, and update them when moving in x or in diffusion time t.

    The covariance is updated with the BFGS formula and the Hessian update is derived from it. Each of the four
    matrices is stored as ``diag + U U^T - V V^T``: a diagonal (1-D tensor of length ``data_dim``), positive low-rank
    factors ``U`` and negative low-rank factors ``V`` (both of shape ``(data_dim, k)``). Plain transposes (no complex
    conjugation) are used throughout, and the factors may be complex because matrix square roots of indefinite
    matrices are taken.
    """

    def __init__(self, init_denoiser_variance, init_noise_variance, data_dim, dtype=torch.complex128):
        """Start from an isotropic covariance ``init_denoiser_variance * I`` and the Hessian ``(Cov / s - I) / s``.

        Args:
            init_denoiser_variance: scalar variance on the diagonal of the initial denoiser covariance.
            init_noise_variance: scalar ``s`` (noise variance) used to derive the initial Hessian diagonal.
            data_dim: number of elements of one flattened sample.
            dtype: working dtype of all stored tensors.
        """
        self.data_dim = data_dim
        self.dtype = dtype
        # The u-v distinction is relevant for the BFGS data representation.
        # U-vectors are the positive low-rank updates, V-vectors are the negative low-rank updates.
        self.vectors_denoiser_cov_u = torch.zeros(data_dim, 0, dtype=dtype)
        self.vectors_inv_denoiser_cov_u = torch.zeros(data_dim, 0, dtype=dtype)
        self.vectors_hessian_u = torch.zeros(data_dim, 0, dtype=dtype)
        self.vectors_inv_hessian_u = torch.zeros(data_dim, 0, dtype=dtype)
        self.vectors_denoiser_cov_v = torch.zeros(data_dim, 0, dtype=dtype)
        self.vectors_inv_denoiser_cov_v = torch.zeros(data_dim, 0, dtype=dtype)
        self.vectors_hessian_v = torch.zeros(data_dim, 0, dtype=dtype)
        self.vectors_inv_hessian_v = torch.zeros(data_dim, 0, dtype=dtype)
        self.diagonal_denoiser_cov = torch.ones(data_dim, dtype=dtype) * init_denoiser_variance
        self.diagonal_inv_denoiser_cov = 1 / self.diagonal_denoiser_cov
        self.diagonal_hessian = (init_denoiser_variance / init_noise_variance - 1) / init_noise_variance * torch.ones(data_dim, dtype=dtype)
        self.diagonal_inv_hessian = 1 / self.diagonal_hessian

    def to_complex(self, x):
        """Return x as complex64 (from float32) or complex128 (from float64); any other dtype is returned unchanged."""
        if x.dtype == torch.float32:
            return x.to(torch.complex64)
        elif x.dtype == torch.float64:
            return x.to(torch.complex128)
        else:
            return x

    def _as_vector(self, x):
        """Flatten x to 1-D and cast it to the working dtype so it can be multiplied with the stored factors."""
        # matmul needs identical dtypes on both operands, so the cast targets self.dtype rather than "some complex type"
        return x.reshape(-1).to(self.dtype)

    def _check_single_sample(self, x, name):
        """Raise ValueError unless x is one sample of data_dim elements, with or without a leading axis of size 1."""
        if x.dim() > 1 and x.shape[0] != 1:
            raise ValueError(f"Batch size must be 1, but {name} has shape {tuple(x.shape)}")
        if x.numel() != self.data_dim:
            raise ValueError(f"{name} has {x.numel()} elements, expected data_dim={self.data_dim}")

    @staticmethod
    def _apply(diag, U, V, vec):
        """Return (diag + U U^T - V V^T) @ vec without forming the dense matrix."""
        return diag * vec + U @ (U.T @ vec) - V @ (V.T @ vec)

    def sqrtm(self, A):
        """Matrix square root of a 2-D tensor via scipy.linalg.sqrtm, returned with A's dtype and device.

        An empty matrix (either dimension 0) is returned as zeros of the same shape without calling SciPy.
        """
        if A.shape[0] == 0 or A.shape[1] == 0:
            return torch.zeros_like(A)
        # detach/cpu so that the NumPy round trip also works for grad-tracking or non-CPU tensors
        root = sqrtm(A.detach().cpu().numpy())
        return torch.from_numpy(np.asarray(root).astype(np.complex128)).to(device=A.device, dtype=A.dtype)

    def woodbury_inverse_from_diag_plus_lowrank(self, diag_inv, U):
        """Return V_inv such that (diag + U U^T)^-1 = diag_inv - V_inv V_inv^T.

        Uses (diag + U U^T)^-1 = diag_inv - diag_inv U (I + U^T diag_inv U)^-1 U^T diag_inv, so
        V_inv = diag_inv U S with S the matrix square root of (I + U^T diag_inv U)^-1.

        Args:
            diag_inv: shape (d,), the inverse of the diagonal part.
            U: shape (d, k), the positive low-rank factors.

        Returns:
            V_inv of shape (d, k); it enters with a negative sign.
        """
        work_dtype = torch.promote_types(diag_inv.dtype, U.dtype)
        diag_inv = diag_inv.to(work_dtype)
        U = U.to(work_dtype)
        scaled_U = diag_inv[..., None] * U
        k = U.shape[-1]
        if k == 0:
            # no low-rank part: the correction is empty as well
            return scaled_U
        identity = torch.eye(k, dtype=work_dtype, device=U.device)
        inner_inv = torch.linalg.inv(identity + U.T @ scaled_U)
        # make sure that symmetric
        inner_inv = (inner_inv + inner_inv.T) / 2
        # There is no guarantee that inner_inv is positive definite, so a Cholesky factor cannot be used;
        # a general (possibly complex) matrix square root is taken instead.
        return scaled_U @ self.sqrtm(inner_inv)

    def woodbury_inverse_from_diag_plus_lowrank_minus_lowrank(self, U, V, diag):
        """Invert diag + U U^T - V V^T and return the inverse in the same representation.

        With A = diag + U U^T, first A^-1 = diag_inv - V_inv V_inv^T is computed, then
        (A - V V^T)^-1 = A^-1 + A^-1 V (I - V^T A^-1 V)^-1 V^T A^-1, where
        I - V^T A^-1 V = I - V^T diag_inv V + K^T K and K = V_inv^T V.

        Args:
            U: shape (d, ku), positive factors.
            V: shape (d, kv), negative factors.
            diag: shape (d,), diagonal part.

        Returns:
            (diag_inv, U_inv, V_inv) such that the inverse equals diag_inv + U_inv U_inv^T - V_inv V_inv^T.
        """
        work_dtype = torch.promote_types(torch.promote_types(U.dtype, V.dtype), diag.dtype)
        U, V, diag = U.to(work_dtype), V.to(work_dtype), diag.to(work_dtype)
        diag_inv = 1 / diag
        # A^-1 = (diag + U U^T)^-1 in the format diag_inv - V_inv V_inv^T
        V_inv = self.woodbury_inverse_from_diag_plus_lowrank(diag_inv, U)
        scaled_V = diag_inv[..., None] * V
        k = V.shape[-1]
        if k == 0:
            # nothing is subtracted, so there is no positive correction to the inverse
            return diag_inv, scaled_V, V_inv
        K = V_inv.T @ V
        identity = torch.eye(k, dtype=work_dtype, device=V.device)
        inner_inv = torch.linalg.inv(identity - V.T @ scaled_V + K.T @ K)
        # symmetrise, as in the first Woodbury step, so that S S^T reproduces inner_inv
        inner_inv = (inner_inv + inner_inv.T) / 2
        V_inner_inv_sqrt = V @ self.sqrtm(inner_inv)
        # U_inv = A^-1 V S, with A^-1 applied through its diag_inv - V_inv V_inv^T form
        U_inv = diag_inv[..., None] * V_inner_inv_sqrt - V_inv @ (V_inv.T @ V_inner_inv_sqrt)
        return diag_inv, U_inv, V_inv

    def sherman_morrison_update(self, U, V, diag, v, pos):
        """Rank-one update of a stored inverse with the Sherman-Morrison formula.

        ``diag + U U^T - V V^T`` represents A^-1. The method returns the same representation of
        (A + v v^T)^-1 when ``pos`` is true and of (A - v v^T)^-1 otherwise, by appending one column to U or to V.
        Which factor receives the column depends on the sign of the real part of the Sherman-Morrison denominator
        ``1 +/- v^T A^-1 v``. A denominator of exactly zero (singular update) is not handled specially and yields
        a non-finite column.

        Args:
            U, V: factors of shape (d, ku) and (d, kv).
            diag: diagonal of shape (d,).
            v: update vector of shape (d,).
            pos: True for A + v v^T, False for A - v v^T.

        Returns:
            (U, V, diag) of the updated inverse; diag is returned unchanged.
        """
        inv_times_v = U @ (U.T @ v) - V @ (V.T @ v) + diag * v
        quadratic_form = v @ inv_times_v
        denominator = 1 + quadratic_form if pos else 1 - quadratic_form
        # complex tensors have no ordering, so the sign test uses the real part
        non_positive = bool(denominator.real <= 0)
        scale = (-denominator if non_positive else denominator).sqrt()
        new_column = (inv_times_v / scale)[:, None]
        # positive update with positive denominator subtracts from the inverse (V); every sign flip swaps the target
        if bool(pos) == non_positive:
            return torch.cat((U, new_column), dim=-1), V, diag
        return U, torch.cat((V, new_column), dim=-1), diag

    def sherman_morrison_double_update(self, U, V, diag, u, v):
        """Return the representation of (A + u u^T - v v^T)^-1 given A^-1 = diag + U U^T - V V^T.

        u is the positive update, v is the negative update; (A + u u^T)^-1 is computed first and
        the negative update is applied to that result.
        """
        U_updated, V_updated, diag_updated = self.sherman_morrison_update(U, V, diag, u, True)
        U_updated, V_updated, diag_updated = self.sherman_morrison_update(U_updated, V_updated, diag_updated, v, False)
        return U_updated, V_updated, diag_updated

    def update_time_step(self, x_t, sigma_t, sigma_tnext, score_t):
        """Move the stored matrices from noise level sigma_t to sigma_tnext and propagate the score.

        The inverse covariance diagonal is shifted by ``sigma_tnext^-2 - sigma_t^-2`` and the inverse Hessian diagonal
        by ``-(sigma_tnext^2 - sigma_t^2)``; the covariance and the Hessian are then recomputed by inverting these.

        Args:
            x_t: one sample, either flat or with a leading batch axis of size 1.
            sigma_t: current noise level.
            sigma_tnext: next noise level.
            score_t: score at x_t, same number of elements as x_t.

        Returns:
            (new_denoiser_mean, new_score_value), both reshaped to x_t's shape, in the working dtype with zero
            imaginary part.

        Raises:
            ValueError: if x_t is not a single sample of data_dim elements.
        """
        shape = x_t.shape  # e.g., (bs, C, H, W)
        self._check_single_sample(x_t, "x_t")
        x_t = self._as_vector(x_t)  # flatten to (C*H*W)
        score_t = self._as_vector(score_t)  # flatten to (C*H*W)

        # The shifts are added as scalars so they are not rounded through a float32 helper tensor.
        self.diagonal_inv_denoiser_cov = self.diagonal_inv_denoiser_cov + (sigma_tnext ** (-2) - sigma_t ** (-2))
        self.diagonal_denoiser_cov, self.vectors_denoiser_cov_u, self.vectors_denoiser_cov_v = self.woodbury_inverse_from_diag_plus_lowrank_minus_lowrank(
            self.vectors_inv_denoiser_cov_u, self.vectors_inv_denoiser_cov_v, self.diagonal_inv_denoiser_cov)

        # Then the hessian...
        new_diagonal_inv_hessian = self.diagonal_inv_hessian - (sigma_tnext ** 2 - sigma_t ** 2)
        new_diag_hessian, new_u_hessian, new_v_hessian = self.woodbury_inverse_from_diag_plus_lowrank_minus_lowrank(
            self.vectors_inv_hessian_u, self.vectors_inv_hessian_v, new_diagonal_inv_hessian)

        # Score function at time t_next (new_hessian @ old_inv_hessian @ score_t); the old inverse must be used here,
        # which is why self.diagonal_inv_hessian is only overwritten afterwards.
        old_inv_hessian_score_t = self._apply(self.diagonal_inv_hessian, self.vectors_inv_hessian_u, self.vectors_inv_hessian_v, score_t)
        new_score_value = self._apply(new_diag_hessian, new_u_hessian, new_v_hessian, old_inv_hessian_score_t).real + 0j
        # Denoiser mean at time t_next
        new_denoiser_mean = (x_t + sigma_tnext ** 2 * new_score_value).real + 0j

        self.diagonal_inv_hessian = new_diagonal_inv_hessian
        self.diagonal_hessian, self.vectors_hessian_u, self.vectors_hessian_v = new_diag_hessian, new_u_hessian, new_v_hessian

        return new_denoiser_mean.reshape(shape), new_score_value.reshape(shape)

    def denoiser_cov_vector_dot(self, v):
        """Return the real part of (denoiser covariance) @ v, reshaped to v's shape."""
        shape = v.shape  # e.g., (bs, C, H, W)
        product = self._apply(self.diagonal_denoiser_cov, self.vectors_denoiser_cov_u, self.vectors_denoiser_cov_v, self._as_vector(v))
        return product.real.reshape(shape)

    def inv_denoiser_cov_vector_dot(self, v):
        """Return the real part of (inverse denoiser covariance) @ v, reshaped to v's shape."""
        shape = v.shape  # e.g., (bs, C, H, W)
        product = self._apply(self.diagonal_inv_denoiser_cov, self.vectors_inv_denoiser_cov_u, self.vectors_inv_denoiser_cov_v, self._as_vector(v))
        return product.real.reshape(shape)

    def hessian_vector_dot(self, v):
        """Return the real part of (Hessian) @ v, reshaped to v's shape."""
        shape = v.shape  # e.g., (bs, C, H, W)
        product = self._apply(self.diagonal_hessian, self.vectors_hessian_u, self.vectors_hessian_v, self._as_vector(v))
        return product.real.reshape(shape)

    def inv_hessian_vector_dot(self, v):
        """Return the real part of (inverse Hessian) @ v, reshaped to v's shape."""
        shape = v.shape  # e.g., (bs, C, H, W)
        product = self._apply(self.diagonal_inv_hessian, self.vectors_inv_hessian_u, self.vectors_inv_hessian_v, self._as_vector(v))
        return product.real.reshape(shape)

    def update_space_step(self, denoiser_mean_at_x, denoiser_mean_at_xnext, sigma_t, x, xnext):
        """BFGS update of the denoiser covariance and hessian and the inverses after a move from x to xnext.

        With dx = xnext - x and de = sigma_t^2 * (mean(xnext) - mean(x)) the covariance becomes
        Cov - Cov dx dx^T Cov / (dx^T Cov dx) + de de^T / (dx^T de), i.e. one column is appended to both the
        U and the V factors. The Hessian follows from H = (Cov / sigma_t^2 - I) / sigma_t^2 and both inverses are
        recomputed with the Woodbury identity.

        Args:
            denoiser_mean_at_x: denoiser mean evaluated at x.
            denoiser_mean_at_xnext: denoiser mean evaluated at xnext (same noise level).
            sigma_t: noise level.
            x: one sample, either flat or with a leading batch axis of size 1.
            xnext: the sample after the step, same number of elements as x.

        Raises:
            ValueError: if x is not a single sample of data_dim elements.
        """
        self._check_single_sample(x, "x")
        x = self._as_vector(x)  # flatten to (C*H*W)
        xnext = self._as_vector(xnext)
        denoiser_mean_at_x = self._as_vector(denoiser_mean_at_x)
        denoiser_mean_at_xnext = self._as_vector(denoiser_mean_at_xnext)

        dx = xnext - x
        de = sigma_t ** 2 * (denoiser_mean_at_xnext - denoiser_mean_at_x)
        gamma = 1 / (dx @ de)

        # Update the denoiser covariance.
        # denoiser_cov_vector_dot returns a real tensor; cast it back so the dot product below has matching dtypes.
        denoiser_cov_dot_dx = self.denoiser_cov_vector_dot(dx).to(self.dtype)
        dx_dot_denoiser_cov_dot_dx = denoiser_cov_dot_dx @ dx
        v = denoiser_cov_dot_dx / torch.sqrt(dx_dot_denoiser_cov_dot_dx)
        u = de * torch.sqrt(gamma)
        new_diagonal_denoiser_cov = self.diagonal_denoiser_cov
        new_vectors_denoiser_cov_u = torch.cat((self.vectors_denoiser_cov_u, u[:, None]), dim=-1)
        new_vectors_denoiser_cov_v = torch.cat((self.vectors_denoiser_cov_v, v[:, None]), dim=-1)

        # Update the inverse denoiser covariance (Woodbury identity on the full updated representation)
        new_diagonal_inv_denoiser_cov, new_vectors_inv_denoiser_cov_u, new_vectors_inv_denoiser_cov_v = self.woodbury_inverse_from_diag_plus_lowrank_minus_lowrank(
            U=new_vectors_denoiser_cov_u, V=new_vectors_denoiser_cov_v, diag=new_diagonal_denoiser_cov)

        # Update the Hessian based on the denoiser covariance update: H = (Dcov/sigma^2 - I)/sigma^2,
        # so the diagonal is transformed and the low-rank factors are scaled by 1/sigma^2.
        new_diagonal_hessian = (new_diagonal_denoiser_cov / sigma_t ** 2 - 1) / sigma_t ** 2
        u_hessian = u / sigma_t ** 2
        v_hessian = v / sigma_t ** 2
        new_vectors_hessian_u = torch.cat((self.vectors_hessian_u, u_hessian[:, None]), dim=-1)
        new_vectors_hessian_v = torch.cat((self.vectors_hessian_v, v_hessian[:, None]), dim=-1)

        new_diagonal_inv_hessian, new_vectors_inv_hessian_u, new_vectors_inv_hessian_v = self.woodbury_inverse_from_diag_plus_lowrank_minus_lowrank(
            U=new_vectors_hessian_u, V=new_vectors_hessian_v, diag=new_diagonal_hessian)

        # Apply all the updates
        self.diagonal_denoiser_cov, self.vectors_denoiser_cov_u, self.vectors_denoiser_cov_v = new_diagonal_denoiser_cov, new_vectors_denoiser_cov_u, new_vectors_denoiser_cov_v
        self.diagonal_inv_denoiser_cov, self.vectors_inv_denoiser_cov_u, self.vectors_inv_denoiser_cov_v = new_diagonal_inv_denoiser_cov, new_vectors_inv_denoiser_cov_u, new_vectors_inv_denoiser_cov_v
        self.diagonal_hessian, self.vectors_hessian_u, self.vectors_hessian_v = new_diagonal_hessian, new_vectors_hessian_u, new_vectors_hessian_v
        self.diagonal_inv_hessian, self.vectors_inv_hessian_u, self.vectors_inv_hessian_v = new_diagonal_inv_hessian, new_vectors_inv_hessian_u, new_vectors_inv_hessian_v

    def get_dense_matrices(self):
        """Return (denoiser_cov, inv_denoiser_cov, hessian, inv_hessian) as dense (data_dim, data_dim) tensors."""
        def dense(diag, U, V):
            return torch.diag(diag) + U @ U.T - V @ V.T

        denoiser_cov = dense(self.diagonal_denoiser_cov, self.vectors_denoiser_cov_u, self.vectors_denoiser_cov_v)
        inv_denoiser_cov = dense(self.diagonal_inv_denoiser_cov, self.vectors_inv_denoiser_cov_u, self.vectors_inv_denoiser_cov_v)
        hessian = dense(self.diagonal_hessian, self.vectors_hessian_u, self.vectors_hessian_v)
        inv_hessian = dense(self.diagonal_inv_hessian, self.vectors_inv_hessian_u, self.vectors_inv_hessian_v)
        return denoiser_cov, inv_denoiser_cov, hessian, inv_hessian

    def set_others_corresponding_to_current_denoiser_cov(self, sigma):
        """Recompute the inverse covariance, the Hessian and the inverse Hessian from the current denoiser covariance.

        The Hessian is derived as (Cov / sigma^2 - I) / sigma^2. Afterwards the dense matrices are built and it is
        asserted (via Cholesky factorisation) that the covariance and its inverse are positive definite and that the
        Hessian and its inverse are negative definite.
        """
        # Factors may have been assigned from outside in another dtype; bring them to the working dtype first.
        self.diagonal_denoiser_cov = self.diagonal_denoiser_cov.to(self.dtype)
        self.vectors_denoiser_cov_u = self.vectors_denoiser_cov_u.to(self.dtype)
        self.vectors_denoiser_cov_v = self.vectors_denoiser_cov_v.to(self.dtype)

        self.diagonal_inv_denoiser_cov, self.vectors_inv_denoiser_cov_u, self.vectors_inv_denoiser_cov_v = self.woodbury_inverse_from_diag_plus_lowrank_minus_lowrank(
            self.vectors_denoiser_cov_u, self.vectors_denoiser_cov_v, self.diagonal_denoiser_cov)
        self.diagonal_hessian = (self.diagonal_denoiser_cov / sigma ** 2 - 1) / sigma ** 2
        self.vectors_hessian_u = self.vectors_denoiser_cov_u / sigma ** 2
        self.vectors_hessian_v = self.vectors_denoiser_cov_v / sigma ** 2
        self.diagonal_inv_hessian, self.vectors_inv_hessian_u, self.vectors_inv_hessian_v = self.woodbury_inverse_from_diag_plus_lowrank_minus_lowrank(
            self.vectors_hessian_u, self.vectors_hessian_v, self.diagonal_hessian)

        a, b, c, d = self.get_dense_matrices()

        # Check if matrices are positive/negative definite
        def is_positive_definite(matrix):
            try:
                torch.linalg.cholesky(matrix)
                return True
            except RuntimeError:
                return False
        assert is_positive_definite(a), "Denoiser covariance matrix is not positive definite"
        assert is_positive_definite(b), "Inverse denoiser covariance matrix is not positive definite"
        assert is_positive_definite(-c), "-Hessian matrix is not positive definite"
        assert is_positive_definite(-d), "-Inverse Hessian matrix is not positive definite"

def update_covariance(samples, denoiser_cov, inv_denoiser_cov, hessian, inv_hessian, score_value, denoiser_mean, schedule, t, tnext):
    """
    Update the denoiser covariance, hessian, score function, and denoiser mean for a batch of samples
    at a new time step using a Gaussian approximation of the noisy distribution.

    Only ``samples``, ``inv_denoiser_cov``, ``inv_hessian``, ``score_value`` and the schedule are read;
    ``denoiser_cov``, ``hessian`` and ``denoiser_mean`` are accepted for signature symmetry but not used.

    Args:
        samples (torch.Tensor): Batch of samples, shape (bs, d)
        denoiser_cov (torch.Tensor): Batch of denoiser covariance matrices, shape (bs, d, d) (unused)
        inv_denoiser_cov (torch.Tensor): Batch of inverse denoiser covariance matrices, shape (bs, d, d)
        hessian (torch.Tensor): Batch of hessian matrices, shape (bs, d, d) (unused)
        inv_hessian (torch.Tensor): Batch of inverse hessian matrices, shape (bs, d, d)
        score_value (torch.Tensor): Batch of score function values, shape (bs, d)
        denoiser_mean (torch.Tensor): Batch of denoiser mean values, shape (bs, d) (unused)
        schedule (callable): Function that returns the noise level at a given time
        t (float): Current time step
        tnext (float): Next time step

    Returns:
        tuple: Updated values for denoiser_cov, inv_denoiser_cov, hessian, inv_hessian, score_value, denoiser_mean
    """
    dim = samples.shape[-1]
    sigma_t = schedule(t)
    sigma_tnext = schedule(tnext)

    # Build the identities in the dtype/device of the matrices they are added to; a default float32 identity
    # would round the scalar shift to float32 before it reaches float64/complex128 inputs.
    identity_cov = torch.eye(dim, dtype=inv_denoiser_cov.dtype, device=inv_denoiser_cov.device)
    identity_hessian = torch.eye(dim, dtype=inv_hessian.dtype, device=inv_hessian.device)

    # Update the inverse covariance matrix
    new_inv_denoiser_cov = inv_denoiser_cov + (sigma_tnext ** (-2) - sigma_t ** (-2)) * identity_cov
    new_denoiser_cov = torch.linalg.inv(new_inv_denoiser_cov)

    new_inv_hessian = inv_hessian - (sigma_tnext ** 2 - sigma_t ** 2) * identity_hessian
    new_hessian = torch.linalg.inv(new_inv_hessian)

    # Score function at time t_next
    new_score_value = (new_hessian @ inv_hessian @ score_value[..., None])[..., 0]
    # Denoiser mean at time t_next
    new_denoiser_mean = samples + sigma_tnext ** 2 * new_score_value

    return new_denoiser_cov, new_inv_denoiser_cov, new_hessian, new_inv_hessian, new_score_value, new_denoiser_mean

def update_bfgs(denoiser_cov, inv_denoiser_cov, denoiser_mean_at_x, denoiser_mean_at_xnext, schedule, t, x, dx):
    """
    Dense BFGS update of the denoiser covariance, its inverse, and the derived Hessian and inverse Hessian.

    With ``de = schedule(t)**2 * (denoiser_mean_at_xnext - denoiser_mean_at_x)`` and ``gamma = 1 / (dx^T de)``:

        Cov      <- Cov - Cov dx dx^T Cov / (dx^T Cov dx) + gamma * de de^T
        Cov^-1   <- (I - gamma dx de^T) Cov^-1 (I - gamma de dx^T) + gamma * dx dx^T
        Hessian  <- (Cov / sigma^2 - I) / sigma^2
        Hess^-1  <- inv(Hessian + 1e-10 * I)

    Args:
        denoiser_cov (torch.Tensor): Current denoiser covariance, shape (d, d) or (bs, d, d).
        inv_denoiser_cov (torch.Tensor): Current inverse denoiser covariance, same shape as denoiser_cov.
        denoiser_mean_at_x (torch.Tensor): Denoiser mean at point x and time t, shape (bs, d).
        denoiser_mean_at_xnext (torch.Tensor): Denoiser mean at point x+dx and time t, shape (bs, d).
        schedule (callable): Function that returns the noise level at a given time.
        t (float): Current time.
        x (torch.Tensor): Current point, shape (bs, d); only its shape is used.
        dx (torch.Tensor): Step taken from x to x+dx, shape (bs, d).

    Returns:
        tuple: (updated_denoiser_cov, updated_inv_denoiser_cov, updated_hessian, updated_inv_hessian),
        each of shape (bs, d, d).

    Note:
        Both denoiser means are evaluated at the same diffusion time t, at two different points.
    """
    bs, d = x.shape
    sigma_sq = schedule(t) ** 2
    identity = torch.eye(d, dtype=denoiser_cov.dtype, device=denoiser_cov.device).unsqueeze(0).repeat(bs, 1, 1)  # shape (bs, d, d)
    de = sigma_sq * (denoiser_mean_at_xnext - denoiser_mean_at_x)  # shape (bs, d)

    # column (bs, d, 1) and row (bs, 1, d) views of the two difference vectors
    dx_col, dx_row = dx[..., :, None], dx[..., None, :]
    de_col, de_row = de[..., :, None], de[..., None, :]

    gamma = 1 / (dx_row @ de_col)  # shape (bs, 1, 1)

    cov_dx = denoiser_cov @ dx_col  # Cov dx, shape (bs, d, 1)
    dx_cov = dx_row @ denoiser_cov  # dx^T Cov, shape (bs, 1, d)
    updated_denoiser_cov = denoiser_cov - cov_dx @ dx_cov / (dx_row @ cov_dx) + de_col @ de_row * gamma
    updated_inv_denoiser_cov = (identity - dx_col @ de_row * gamma) @ inv_denoiser_cov @ (identity - de_col @ dx_row * gamma) + dx_col @ dx_row * gamma

    updated_hessian = (updated_denoiser_cov / sigma_sq - identity) / sigma_sq
    # add a jitter term to make it invertible
    updated_inv_hessian = torch.linalg.inv(updated_hessian + 1e-10 * identity)

    return updated_denoiser_cov, updated_inv_denoiser_cov, updated_hessian, updated_inv_hessian

# import torch
# import numpy as np
# from conditioning_utils.covariance_hessian_bfgs import CovarianceHessianBFGS

def test_covariance_hessian_time_update():
    """Compare one time step of CovarianceHessianBFGS against the dense update_covariance reference.

    Both start from an identity covariance (no low-rank factors), so after the step the four dense
    matrices reconstructed from the BFGS object must agree with the dense reference.
    """
    d = 5
    # float64 so that the flattened input becomes complex128, the default working dtype of the BFGS object
    dtype = torch.float64

    # Define a simple score function (e.g., linear function)
    def score_fn(x, t):
        return -x / (t ** 2)

    # Define schedule
    def schedule(t):
        return t

    # Initialize parameters
    t = 20.0
    tnext = 18.0
    x = torch.randn(d).to(dtype)

    # Initialize dense matrices
    denoiser_cov = torch.eye(d, dtype=dtype)
    inv_denoiser_cov = torch.eye(d, dtype=dtype)
    hessian = (denoiser_cov / schedule(t) ** 2 - torch.eye(d, dtype=dtype)) / schedule(t) ** 2
    inv_hessian = torch.linalg.inv(hessian)

    # Initialize CovarianceHessianBFGS
    bfgs = CovarianceHessianBFGS(init_denoiser_variance=1, init_noise_variance=schedule(t) ** 2, data_dim=d)

    # Compute score and denoiser mean at the current time
    score_at_t = score_fn(x, t)
    denoiser_mean_at_t = x + (schedule(t) ** 2) * score_at_t

    # Update dense matrices
    updated_denoiser_cov, updated_inv_denoiser_cov, updated_hessian, updated_inv_hessian, new_score_value, new_denoiser_mean = update_covariance(
        x, denoiser_cov, inv_denoiser_cov, hessian, inv_hessian,
        score_at_t, denoiser_mean_at_t,
        schedule, t, tnext
    )

    # Update BFGS representation; update_time_step checks for a leading batch axis of size 1
    bfgs.update_time_step(x[None, :], schedule(t), schedule(tnext), score_at_t[None, :])

    # Compare results
    bfgs_denoiser_cov, bfgs_inv_denoiser_cov, bfgs_hessian, bfgs_inv_hessian = bfgs.get_dense_matrices()

    denoiser_cov_error = torch.norm(updated_denoiser_cov - bfgs_denoiser_cov).item()
    inv_denoiser_cov_error = torch.norm(updated_inv_denoiser_cov - bfgs_inv_denoiser_cov).item()
    hessian_error = torch.norm(updated_hessian - bfgs_hessian).item()
    inv_hessian_error = torch.norm(updated_inv_hessian - bfgs_inv_hessian).item()

    print("Denoiser Covariance Error:", denoiser_cov_error)
    print("Inverse Denoiser Covariance Error:", inv_denoiser_cov_error)
    print("Hessian Error:", hessian_error)
    print("Inverse Hessian Error:", inv_hessian_error)

    # Everything is diagonal here, so the two implementations should agree closely.
    tolerance = 1e-8
    assert denoiser_cov_error < tolerance, "Denoiser covariance after the time step does not match the dense reference."
    assert inv_denoiser_cov_error < tolerance, "Inverse denoiser covariance after the time step does not match the dense reference."
    assert hessian_error < tolerance, "Hessian after the time step does not match the dense reference."
    assert inv_hessian_error < tolerance, "Inverse Hessian after the time step does not match the dense reference."

def test_covariance_hessian_time_update_with_u_and_v():
    """Time-step comparison when the covariance already carries low-rank U/V factors.

    For 1, 2, 4 and 8 (u, v) pairs a covariance I + sum(u u^T - v v^T) is built, loaded into the BFGS
    object, checked against its dense counterpart, and then advanced one time step with both
    implementations. The post-step differences are printed.
    """
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Define a simple score function (e.g., linear function)
    def score_fn(x, t):
        return -x / (t ** 2)

    # Define schedule
    def schedule(t):
        return t

    # Vectors are drawn in float32 and then promoted: float64 flattens to complex128 inside the BFGS object,
    # which matches its default working dtype, and the reconstruction tolerances below assume double precision.
    draw_dtype = torch.float32
    dtype = torch.float64

    # Define dimensions
    d = 15

    # Initialize parameters
    t = 80.0
    tnext = 79.0

    for num_u_v_pairs in [1, 2, 4, 8]:
        x = torch.randn(d, dtype=draw_dtype).to(dtype)

        # Initialize dense covariance
        identity = torch.eye(d, dtype=dtype)
        denoiser_cov = identity.clone()

        U = []
        V = []
        for _ in range(num_u_v_pairs):
            # Add vector outer products to denoiser covariance
            u = torch.randn(d, dtype=draw_dtype).to(dtype)
            v = torch.randn(d, dtype=draw_dtype).to(dtype)
            u = u / torch.norm(u)  # Normalize u
            v = v / torch.norm(v)  # Normalize v

            # Make v orthogonal to u and shrink it, to keep the covariance positive definite
            v = v - torch.dot(u, v) * u
            v = v / torch.norm(v) * np.sqrt(0.5)  # Renormalize v

            # Add uu^T - vv^T to denoiser_cov
            denoiser_cov += torch.outer(u, u) - torch.outer(v, v)
            U.append(u[:, None])
            V.append(v[:, None])

        U = torch.cat(U, dim=1)
        V = torch.cat(V, dim=1)

        # Check positive definiteness using Cholesky decomposition
        try:
            torch.linalg.cholesky(denoiser_cov)
        except RuntimeError as err:
            raise ValueError("The denoiser covariance matrix is not positive definite.") from err

        # Derive inv_denoiser_cov and hessian from the covariance
        inv_denoiser_cov = torch.linalg.inv(denoiser_cov)
        hessian = (denoiser_cov / schedule(t) ** 2 - identity) / schedule(t) ** 2
        inv_hessian = torch.linalg.inv(hessian)

        # Initialize CovarianceHessianBFGS
        bfgs = CovarianceHessianBFGS(init_denoiser_variance=1, init_noise_variance=schedule(t) ** 2, data_dim=d)

        # The factors must share the object's dtype, otherwise the matrix products inside it fail
        bfgs.vectors_denoiser_cov_u = U.to(bfgs.dtype)
        bfgs.vectors_denoiser_cov_v = V.to(bfgs.dtype)
        bfgs.set_others_corresponding_to_current_denoiser_cov(schedule(t))

        bfgs_denoiser_cov, bfgs_inv_denoiser_cov, bfgs_hessian, bfgs_inv_hessian = bfgs.get_dense_matrices()
        assert torch.norm(bfgs_denoiser_cov - denoiser_cov).item() / d ** 2 < 1e-8, "Reconstructed denoiser covariance does not match the original."
        assert torch.norm(bfgs_inv_denoiser_cov - inv_denoiser_cov).item() / d ** 2 < 1e-7, "Reconstructed inverse denoiser covariance does not match the original."
        assert torch.norm(bfgs_hessian - hessian).item() / d ** 2 < 1e-10, "Reconstructed Hessian does not match the original."
        assert torch.norm(bfgs_inv_hessian - inv_hessian).item() / d ** 2 < 1e-4, "Reconstructed inverse Hessian does not match the original."

        # Compute score and denoiser mean at the current time
        score_at_t = score_fn(x, t)
        denoiser_mean_at_t = x + (schedule(t) ** 2) * score_at_t

        # Update dense matrices
        updated_denoiser_cov, updated_inv_denoiser_cov, updated_hessian, updated_inv_hessian, new_score_value, new_denoiser_mean = update_covariance(
            x, denoiser_cov, inv_denoiser_cov, hessian, inv_hessian,
            score_at_t, denoiser_mean_at_t,
            schedule, t, tnext
        )

        # Update BFGS representation; update_time_step checks for a leading batch axis of size 1
        bfgs.update_time_step(x[None, :], schedule(t), schedule(tnext), score_at_t[None, :])

        # Compare results
        bfgs_denoiser_cov, bfgs_inv_denoiser_cov, bfgs_hessian, bfgs_inv_hessian = bfgs.get_dense_matrices()

        print("----------------------------------")
        print("Results for num_u_v_pairs =", num_u_v_pairs)
        print("Denoiser Covariance Error:", torch.norm(updated_denoiser_cov - bfgs_denoiser_cov).item() / d ** 2)
        print("Inverse Denoiser Covariance Error:", torch.norm(updated_inv_denoiser_cov - bfgs_inv_denoiser_cov).item() / d ** 2)
        print("Hessian Error:", torch.norm(updated_hessian - bfgs_hessian).item() / d ** 2)
        print("Inverse Hessian Error:", torch.norm(updated_inv_hessian - bfgs_inv_hessian).item() / d ** 2)

def test_bfgs_update():
    """Run ten consecutive BFGS space steps with the dense update_bfgs reference and with CovarianceHessianBFGS.

    After every step the dense matrices reconstructed from the BFGS object are compared with the dense
    reference and the normalised differences are printed.
    """
    # Define dimensions
    d = 15

    # Define a simple score function (e.g., linear function)
    def score_fn(x, t):
        term1 = -x / (t ** 2)
        term2 = -0.5 * (x - torch.ones_like(x)) / (t ** 2)
        return 0.7 * term1 + 0.3 * term2

    # Define schedule
    def schedule(t):
        return t

    # Vectors are drawn in float32 and promoted to float64, which flattens to complex128 inside the BFGS
    # object (its default working dtype); the tolerances below assume double precision.
    draw_dtype = torch.float32
    dtype = torch.float64

    # set random seed
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)

    # Initialize parameters
    t = 50.0
    x = torch.randn(d, dtype=draw_dtype).to(dtype)

    # Initialize dense matrices
    denoiser_cov = torch.eye(d, dtype=dtype)
    inv_denoiser_cov = torch.eye(d, dtype=dtype)
    hessian = (denoiser_cov / schedule(t) ** 2 - torch.eye(d, dtype=dtype)) / schedule(t) ** 2
    inv_hessian = torch.linalg.inv(hessian)

    # Initialize CovarianceHessianBFGS
    bfgs = CovarianceHessianBFGS(init_denoiser_variance=1, init_noise_variance=schedule(t) ** 2, data_dim=d)
    bfgs_denoiser_cov, bfgs_inv_denoiser_cov, bfgs_hessian, bfgs_inv_hessian = bfgs.get_dense_matrices()
    assert torch.norm(bfgs_denoiser_cov - denoiser_cov).item() / d ** 2 < 1e-8, "Reconstructed denoiser covariance does not match the original."
    assert torch.norm(bfgs_inv_denoiser_cov - inv_denoiser_cov).item() / d ** 2 < 1e-7, "Reconstructed inverse denoiser covariance does not match the original."
    assert torch.norm(bfgs_hessian - hessian).item() / d ** 2 < 1e-10, "Reconstructed Hessian does not match the original."
    assert torch.norm(bfgs_inv_hessian - inv_hessian).item() / d ** 2 < 1e-4, "Reconstructed inverse Hessian does not match the original."

    steps = 10

    for _ in range(steps):
        dx = torch.randn(d, dtype=draw_dtype).to(dtype) * 0.1
        xnext = x + dx

        # Denoiser means at both points, at the same time t
        denoiser_mean_at_x = x + (schedule(t) ** 2) * score_fn(x, t)
        denoiser_mean_at_xnext = xnext + (schedule(t) ** 2) * score_fn(xnext, t)

        # compute bfgs update using the dense matrices
        updated_denoiser_cov, updated_inv_denoiser_cov, updated_hessian, updated_inv_hessian = update_bfgs(
            denoiser_cov, inv_denoiser_cov, denoiser_mean_at_x[None, :], denoiser_mean_at_xnext[None, :], schedule, t, x[None, :], dx[None, :])
        denoiser_cov, inv_denoiser_cov, hessian, inv_hessian = updated_denoiser_cov, updated_inv_denoiser_cov, updated_hessian, updated_inv_hessian

        # compute bfgs update using the bfgs representation; update_space_step checks for a leading batch axis of size 1
        bfgs.update_space_step(denoiser_mean_at_x[None, :], denoiser_mean_at_xnext[None, :], schedule(t), x[None, :], xnext[None, :])
        bfgs_denoiser_cov, bfgs_inv_denoiser_cov, bfgs_hessian, bfgs_inv_hessian = bfgs.get_dense_matrices()

        print("---------------BFGS update results-----------------")
        print("Denoiser Covariance Error:", torch.norm(updated_denoiser_cov - bfgs_denoiser_cov).item() / d ** 2)
        print("Inverse Denoiser Covariance Error:", torch.norm(updated_inv_denoiser_cov - bfgs_inv_denoiser_cov).item() / d ** 2)
        print("Hessian Error:", torch.norm(updated_hessian - bfgs_hessian).item() / d ** 2)
        print("Inverse Hessian Error:", torch.norm(updated_inv_hessian - bfgs_inv_hessian).item() / d ** 2)

        x = xnext

def test_time_and_space_updates():
    """Compare dense and low-rank updates over alternating time and space steps.

    Walks down a decreasing list of times. For every consecutive pair
    ``(t, tnext)`` the dense reference matrices are advanced with
    ``update_covariance`` (time step) and then ``update_bfgs`` (space step),
    while the ``CovarianceHessianBFGS`` object is advanced with
    ``update_time_step`` and ``update_space_step``. After each of the two
    sub-steps the discrepancies between both representations are printed.

    Nothing is asserted and nothing is returned; the printed errors are meant
    to be inspected by hand.
    """
    dim = 5
    cdtype = torch.complex128

    # Seed every RNG in play so the printed errors are reproducible.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)

    # Real-valued random starting point, stored in the complex working dtype.
    x = torch.randn(dim, dtype=torch.float32).to(cdtype)

    def score_fn(x, t):
        # 0.7 / 0.3 weighted sum of two terms that are linear in x and scaled by 1/t**2.
        origin_term = -x / (t ** 2)
        ones_term = -0.5 * (x - torch.ones_like(x)) / (t ** 2)
        return 0.7 * origin_term + 0.3 * ones_term

    def schedule(t):
        # Identity schedule: the schedule value at time t is t itself.
        return t

    def report_matrix_errors(dense, low_rank):
        # Print, for each of the four matrices, the norm of the difference divided by dim**2.
        labels = (
            "Denoiser Covariance Error:",
            "Inverse Denoiser Covariance Error:",
            "Hessian Error:",
            "Inverse Hessian Error:",
        )
        for label, reference, candidate in zip(labels, dense, low_rank):
            print(label, torch.norm(reference - candidate).item() / dim**2)

    # Time grid 50, 48, ..., 2 followed by a final step down to 0.1.
    # (A short grid such as [50.0, 30.0, 10.0] is handy for quick debugging.)
    ts = [float(v) for v in range(50, 0, -2)] + [0.1]
    initial_noise_var = schedule(ts[0]) ** 2

    # Dense reference state.
    denoiser_cov = torch.eye(dim, dtype=cdtype)
    inv_denoiser_cov = torch.eye(dim, dtype=cdtype)
    hessian = (denoiser_cov / initial_noise_var - torch.eye(dim, dtype=cdtype)) / initial_noise_var
    inv_hessian = torch.linalg.inv(hessian)

    # Low-rank representation under test, initialised to match the dense state.
    bfgs = CovarianceHessianBFGS(
        init_denoiser_variance=1,
        init_noise_variance=initial_noise_var,
        data_dim=dim,
        dtype=cdtype,
    )

    for round_idx, (t, tnext) in enumerate(zip(ts, ts[1:])):
        # The displacement for the space step is drawn up front, before the time update.
        dx = torch.real(torch.randn(dim, dtype=cdtype) * 0.1).to(cdtype)
        xnext = x + dx

        # ---- time update: t -> tnext at fixed x ----
        score_at_t = score_fn(x, t)
        denoiser_mean_at_t = x + (schedule(t) ** 2) * score_at_t

        # The dense routine is called with a leading batch axis of size one.
        cov_t, inv_cov_t, hess_t, inv_hess_t, dense_score, dense_mean = update_covariance(
            x.unsqueeze(0),
            denoiser_cov.unsqueeze(0),
            inv_denoiser_cov.unsqueeze(0),
            hessian.unsqueeze(0),
            inv_hessian.unsqueeze(0),
            score_at_t.unsqueeze(0),
            denoiser_mean_at_t.unsqueeze(0),
            schedule,
            t,
            tnext,
        )

        low_rank_mean, low_rank_score = bfgs.update_time_step(
            x, schedule(t), schedule(tnext), score_at_t
        )

        print("Round ", round_idx)
        print("---------------Time update results-----------------")
        print("Score Value Error:", torch.norm(dense_score - low_rank_score).item() / dim)
        print("Denoiser Mean Error:", torch.norm(dense_mean - low_rank_mean).item() / dim)
        lr_cov, lr_inv_cov, lr_hess, lr_inv_hess = bfgs.get_dense_matrices()
        report_matrix_errors(
            (cov_t, inv_cov_t, hess_t, inv_hess_t),
            (lr_cov, lr_inv_cov, lr_hess, lr_inv_hess),
        )

        # ---- space update: x -> xnext at time tnext ----
        score_at_xnext = score_fn(xnext, tnext)
        denoiser_mean_at_xnext = xnext + (schedule(tnext) ** 2) * score_at_xnext

        # NOTE(review): the space step reuses the denoiser mean evaluated at time t
        # (before the time update), and passes it to update_bfgs without a leading
        # batch axis, unlike the other vector arguments -- confirm this is intended.
        denoiser_cov, inv_denoiser_cov, hessian, inv_hessian = update_bfgs(
            cov_t[0],
            inv_cov_t[0],
            denoiser_mean_at_t,
            denoiser_mean_at_xnext.unsqueeze(0),
            schedule,
            tnext,
            x.unsqueeze(0),
            dx.unsqueeze(0),
        )

        bfgs.update_space_step(denoiser_mean_at_t, denoiser_mean_at_xnext, schedule(tnext), x, xnext)
        lr_cov, lr_inv_cov, lr_hess, lr_inv_hess = bfgs.get_dense_matrices()

        print("---------------Space update results-----------------")
        report_matrix_errors(
            (denoiser_cov, inv_denoiser_cov, hessian, inv_hessian),
            (lr_cov, lr_inv_cov, lr_hess, lr_inv_hess),
        )

        # The dense state was already rebound above; move on to the new position.
        x = xnext

if __name__ == "__main__":
    # Script entry point. Only the combined time-and-space comparison is run;
    # the other test calls below are commented out and can be re-enabled by hand.
    # test_covariance_hessian_time_update()
    # test_covariance_hessian_time_update_with_u_and_v()
    # test_bfgs_update()
    test_time_and_space_updates()