import numpy as np
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm

def K_RBF(x, y, sigma=1.0):
    """Radial Basis Function (RBF / Gaussian) kernel.

    K(x, y) = exp(-||x - y||^2 / (2 * sigma^2))

    Args:
        x: First point (array-like, same shape as ``y``).
        y: Second point.
        sigma: Kernel bandwidth. Defaults to 1.0, the value previously
            hard-coded here.

    Returns:
        The kernel value, a scalar in [0, 1] (exactly 1 when x == y).
    """
    return np.exp(-np.linalg.norm(x - y) ** 2 / (2 * sigma**2))

def K_IMQ(x, y, c=1, beta=0.5, sigma=1.0):
    """Inverse Multiquadric (IMQ) kernel.

    K(x, y) = (c^2 + ||x - y||^2 / (2 * sigma^2))^(-beta)

    Args:
        x: First point (array-like, same shape as ``y``).
        y: Second point.
        c: Scale parameter (default 1, previously hard-coded).
        beta: Shape parameter / decay exponent (default 0.5, previously
            hard-coded).
        sigma: Length scale applied to the squared distance (default 1.0,
            previously hard-coded).

    Returns:
        The kernel value as a scalar; equals c^(-2*beta) when x == y.
    """
    return 1 / (c**2 + np.linalg.norm(x - y)**2 / (2 * sigma**2))**beta

def grad_F(x):
    """Return the gradient of the example potential F(x) = ||x||^2 / 2.

    Differentiating a half squared Euclidean norm gives the identity map, so
    the gradient at ``x`` is ``x`` itself (the very same object is returned,
    not a copy).
    """
    gradient = x  # d/dx (||x||^2 / 2) = x
    return gradient

def grad2_K_RBF(x, y, sigma=1.0):
    """Gradient of the RBF kernel K(x, y) with respect to its second argument y.

    With K(x, y) = exp(-||x - y||^2 / (2 * sigma^2)):

        grad_y K(x, y) = (x - y) / sigma^2 * K(x, y)

    Args:
        x: First point (array-like, same shape as ``y``).
        y: Second point (the argument being differentiated).
        sigma: Kernel bandwidth. Defaults to 1.0, the value previously
            hard-coded here.

    Returns:
        An array with the same shape as ``x - y``.
    """
    # Kernel value computed inline with the same ``sigma`` so the gradient can
    # never be paired with a kernel of a different bandwidth.
    k = np.exp(-np.linalg.norm(x - y) ** 2 / (2 * sigma**2))
    return (x - y) * k / (sigma**2)

def grad2_K_IMQ(x, y, c=1, beta=0.5, sigma=1.0):
    """Gradient of the IMQ kernel K(x, y) with respect to its second argument y.

    With K(x, y) = (c^2 + ||x - y||^2 / (2 * sigma^2))^(-beta):

        grad_y K(x, y) = -beta * (c^2 + ||x - y||^2 / (2 * sigma^2))^(-beta - 1)
                         * (y - x) / sigma^2

    Args:
        x: First point (array-like, same shape as ``y``).
        y: Second point (the argument being differentiated).
        c: Scale parameter (default 1, previously hard-coded).
        beta: Shape parameter (default 0.5, previously hard-coded).
        sigma: Length scale (default 1.0, previously hard-coded).

    Returns:
        An array with the same shape as ``y - x``.
    """
    norm_x_y_squared = np.linalg.norm(x - y)**2
    return -beta * (c**2 + norm_x_y_squared / (2 * sigma**2))**(-beta-1) * (y - x) / (sigma**2)

def noisy_svgd(X, grad_F, gamma, lambda_, n_iter, d, kernel):
    """Noisy Stein Variational Gradient Descent algorithm.

    At iteration k every particle x_i is updated as

        x_i <- x_i - gamma_k * phi_i
                   - lambda_ * gamma_k * grad_F(x_i)
                   + sqrt(2 * lambda_ * gamma_k) * xi_i

    where
        phi_i = (1/n) * sum_j [K(x_i, x_j) * grad_F(x_j) - grad2_K(x_i, x_j)]
    and xi_i is a fresh standard Gaussian vector. With ``lambda_ = 0`` the
    drift and noise terms vanish and only the kernelised (SVGD) term remains.

    Args:
        X: Initial particles, array of shape (n, d). Non-floating input is
            promoted to float so that updates are not truncated.
        grad_F: Callable returning the gradient of F at a single particle.
        gamma: Step-size sequence; ``gamma[k]`` is used at iteration k, so it
            needs at least ``n_iter`` entries.
        lambda_: Weight of the added Langevin drift + noise term.
        n_iter: Number of iterations.
        d: Particle dimension (shape of the per-particle noise).
        kernel: ``'RBF'`` or ``'IMQ'`` to use this module's kernels, or a
            ``(K, grad2_K)`` pair of callables ``f(x, y)`` returning the
            kernel value and its gradient with respect to ``y``.

    Returns:
        The particles after ``n_iter`` iterations, shape (n, d).

    Raises:
        ValueError: If ``kernel`` is neither a known name nor a pair of
            callables (previously this surfaced as an UnboundLocalError).
    """
    # Resolve the kernel once, up front, instead of re-dispatching on the
    # string for every particle pair.
    if kernel == 'RBF':
        K, grad2_K = K_RBF, grad2_K_RBF
    elif kernel == 'IMQ':
        K, grad2_K = K_IMQ, grad2_K_IMQ
    elif (isinstance(kernel, (tuple, list)) and len(kernel) == 2
          and all(callable(f) for f in kernel)):
        K, grad2_K = kernel
    else:
        raise ValueError(
            f"Unknown kernel {kernel!r}; expected 'RBF', 'IMQ' or a "
            "(K, grad2_K) pair of callables"
        )

    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.inexact):
        # An integer X would make the zeros_like buffer below integral and
        # silently truncate every update.
        X = X.astype(float)

    n = X.shape[0]
    for k in range(n_iter):
        gamma_k = gamma[k]
        noise = np.random.normal(size=(n, d))
        # grad_F(x_j) does not depend on i: evaluate it once per particle
        # per iteration instead of n times.
        grads = [grad_F(x) for x in X]
        X_new = np.zeros_like(X)

        for i in range(n):
            # Negated SVGD direction: attraction via K * grad_F minus the
            # kernel-gradient (repulsive) term.
            direction = np.zeros(d)
            for j in range(n):
                direction += K(X[i], X[j]) * grads[j] - grad2_K(X[i], X[j])
            direction /= n
            X_new[i] = (X[i] - gamma_k * direction
                        - lambda_ * gamma_k * grads[i]
                        + np.sqrt(2 * lambda_ * gamma_k) * noise[i])

        X = X_new
    return X


def langevin(X, grad_F, gamma, n_iter, d):
    """Langevin algorithm applied independently to each particle.

    At iteration k every particle is updated as

        x_i <- x_i - gamma_k * grad_F(x_i) + sqrt(2 * gamma_k) * xi_i

    with xi_i a fresh standard Gaussian vector. There is no accept/reject
    step.

    Args:
        X: Initial particles, array of shape (n, d). Non-floating input is
            promoted to float so that updates are not truncated.
        grad_F: Callable returning the gradient of F at a single particle.
        gamma: Step-size sequence; ``gamma[k]`` is used at iteration k, so it
            needs at least ``n_iter`` entries.
        n_iter: Number of iterations.
        d: Particle dimension (shape of the per-particle noise).

    Returns:
        The particles after ``n_iter`` iterations, shape (n, d).
    """
    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.inexact):
        # An integer X would make the zeros_like buffer below integral and
        # silently truncate every update.
        X = X.astype(float)

    n = X.shape[0]
    for k in range(n_iter):
        gamma_k = gamma[k]
        noise = np.random.normal(size=(n, d))
        noise_scale = np.sqrt(2 * gamma_k)
        X_new = np.zeros_like(X)

        for i, x in enumerate(X):
            X_new[i] = x - gamma_k * grad_F(x) + noise_scale * noise[i]

        X = X_new
    return X



def compute_variance(n, n_iter, gamma, kernel, lambda_, d, runs):
    """Estimate the dimension-averaged marginal variance of noisy SVGD.

    Each run draws ``n`` standard-normal particles in ``d`` dimensions, runs
    :func:`noisy_svgd` on them with the module-level ``grad_F``, and records
    the mean over coordinates of the per-coordinate particle variance
    (``np.var`` with its default ``ddof=0``).

    Returns:
        A tuple ``(d, mean, std)`` where ``mean`` and ``std`` are taken over
        the ``runs`` per-run values.
    """
    damv_per_run = []
    for _run in range(runs):
        particles = noisy_svgd(np.random.randn(n, d), grad_F, gamma, lambda_, n_iter, d, kernel)
        damv_per_run.append(np.var(particles, axis=0).mean())
    return d, np.mean(damv_per_run), np.std(damv_per_run)

def generate_plots(kernel):
    """Sweep noisy SVGD over dimensions and plot the resulting variance.

    For every regularisation strength in ``lambdas`` and particle count in
    ``n_values``, :func:`compute_variance` is evaluated in parallel (one task
    per dimension). The mean dimension-averaged marginal variance is plotted
    against the dimension. For non-zero ``lambda_``, a band of +/- one
    standard deviation across runs is also shaded. The figure is saved to
    ``f'{kernel}_variance_plot.pdf'``, shown, and then closed.

    Args:
        kernel: Kernel name forwarded to :func:`compute_variance`
            (``'RBF'`` or ``'IMQ'``).
    """
    n_values = [10, 20, 30]
    n_iter = 200
    dimensions = [100, 200, 300, 400, 500]
    gamma = [10 / (i + 1) for i in range(n_iter)]  # decaying step size 10 / (k + 1)
    lambdas = [0, 2]
    runs = 10  # independent runs per (lambda, n, d) configuration

    colors = ['b', '#FFA500', 'g']  # Blue, Orange, Green: one colour per n
    linestyles = [':', '-']  # Dotted for lambda=0, solid for lambda=2

    plt.figure(figsize=(10, 6))
    # One worker pool serves the whole sweep rather than one pool per (lambda, n).
    with ProcessPoolExecutor() as executor:
        for lambda_, linestyle in zip(lambdas, linestyles):
            for n, color in zip(n_values, colors):
                futures = [
                    executor.submit(compute_variance, n, n_iter, gamma, kernel, lambda_, d, runs)
                    for d in dimensions
                ]
                results = list(tqdm((future.result() for future in futures), total=len(dimensions), desc=f'Processing dimensions for λ={lambda_}, n={n}'))
                # Each result is (d, mean, std); sorting orders the points by dimension.
                dimensions_sorted, means, stds = zip(*sorted(results))
                # Raw f-string: '\l' is not a valid escape sequence in a normal string literal.
                plt.plot(dimensions_sorted, means, marker='o', linestyle=linestyle, color=color, label=rf'$\lambda$ = {lambda_}, n = {n}')
                if lambda_ != 0:
                    # Shade +/- one standard deviation across runs; no band is drawn for lambda=0.
                    means_arr, stds_arr = np.array(means), np.array(stds)
                    plt.fill_between(dimensions_sorted, means_arr - stds_arr, means_arr + stds_arr, color=color, alpha=0.2)

    plt.axhline(y=1, color='r', linestyle='--', label='Ground truth (DAMV = 1)')
    plt.xlabel('Dimension (d)')
    plt.ylabel('Dimension-averaged Marginal Variance')
    plt.title(f'Dimension-averaged Marginal Variance vs. Dimension with {kernel} Kernel')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'{kernel}_variance_plot.pdf')
    plt.show()
    # Release the figure so repeated calls (one per kernel) do not accumulate open figures.
    plt.close()

def main():
    """Produce and save the variance-vs-dimension figure for each kernel in turn."""
    for kernel_name in ('RBF', 'IMQ'):
        generate_plots(kernel_name)


if __name__ == "__main__":
    main()
