import numpy as np
from scipy.spatial.distance import pdist, squareform


class SVGD:
    """Stein Variational Gradient Descent with an RBF kernel.

    The kernel bandwidth is chosen once, by a median heuristic on the first
    kernel evaluation (unless supplied), and then reused both for the SVGD
    update and for the MMD diagnostic recorded in ``mmd_history``.
    """

    def __init__(self, kernel_bandwidth=None):
        # Bandwidth h of k(x, y) = exp(-||x - y||^2 / (2 h)); None means it is
        # set by the median heuristic on the first call to rbf_kernel.
        self.kernel_bandwidth = kernel_bandwidth
        # One MMD value appended per iteration of update().
        self.mmd_history = []

    def rbf_kernel(self, X, Y=None):
        """
        Compute the RBF kernel matrix between the rows of X and Y.

        :param X: Array of shape (N, D)
        :param Y: Array of shape (M, D); defaults to X
        :return: K of shape (N, M), K[i, j] = exp(-||X_i - Y_j||^2 / (2 h))

        If ``self.kernel_bandwidth`` is None it is set here, and cached, to
        median(squared pairwise distances) / log(N + 1). Only the kernel is
        returned; its gradient is assembled in ``update``.
        """
        if Y is None:
            Y = X
        if X is Y:
            pairwise_dists = squareform(pdist(X, metric='euclidean')) ** 2
        else:
            # ||x||^2 + ||y||^2 - 2 x.y can dip slightly below zero through
            # floating-point cancellation; clamp so that K never exceeds 1.
            pairwise_dists = np.maximum(
                np.sum(X**2, axis=1).reshape(-1, 1) + np.sum(Y**2, axis=1) - 2 * np.dot(X, Y.T),
                0.0,
            )
        if self.kernel_bandwidth is None:
            h = np.median(pairwise_dists) / np.log(X.shape[0] + 1)
            # The heuristic yields 0 (or NaN) when every point coincides, e.g. a
            # single particle, which would make the kernel 0/0 = NaN. Fall back
            # to a unit bandwidth in that case.
            self.kernel_bandwidth = h if np.isfinite(h) and h > 0 else 1.0
        return np.exp(-pairwise_dists / (2 * self.kernel_bandwidth))

    def compute_mmd(self, particles, target_samples, k_yy_mean=None):
        """
        Compute the biased (V-statistic) estimate of the squared Maximum Mean
        Discrepancy between particles and target samples.

        :param particles: Array (N, D)
        :param target_samples: Array (M, D)
        :param k_yy_mean: Optional precomputed mean of K(target, target). It
            depends only on the targets and the fixed bandwidth, so callers
            looping over iterations can compute it once.
        :return: mean(K_XX) + mean(K_YY) - 2 * mean(K_XY)
        """
        K_XX = self.rbf_kernel(particles, particles)
        if k_yy_mean is None:
            k_yy_mean = np.mean(self.rbf_kernel(target_samples, target_samples))
        K_XY = self.rbf_kernel(particles, target_samples)
        return np.mean(K_XX) + k_yy_mean - 2 * np.mean(K_XY)

    def update(self, X, grad_log_prob, target_samples, lr=1e-3, iterations=100):
        """
        Update particles using SVGD.

        :param X: Initial particles (N, D); updated in place and returned
        :param grad_log_prob: Function mapping (N, D) particles to the (N, D)
            gradient of the log target density
        :param target_samples: Samples from the target distribution for MMD calculation
        :param lr: Learning rate
        :param iterations: Number of iterations
        :return: The updated particle array (the same object as X)
        """
        n = X.shape[0]
        k_yy_mean = None
        for _ in range(iterations):
            grad_log_p = grad_log_prob(X)
            # Evaluate the kernel once and reuse it for both SVGD terms.
            K = self.rbf_kernel(X)
            # Repulsive term: sum_j grad_{x_j} k(x_j, x_i)
            #   = (x_i * sum_j k_ij - sum_j k_ij x_j) / h
            grad_K = (X * np.sum(K, axis=1, keepdims=True) - K @ X) / self.kernel_bandwidth
            svgd_grad = (K @ grad_log_p + grad_K) / n
            X += lr * svgd_grad
            if k_yy_mean is None:
                # The bandwidth is fixed by now, so this term is loop-invariant.
                k_yy_mean = np.mean(self.rbf_kernel(target_samples, target_samples))
            self.mmd_history.append(self.compute_mmd(X, target_samples, k_yy_mean))
        return X
class NSVGD:
    """Noisy SVGD: the SVGD direction plus a Langevin drift and diffusion weighted by ``lam``.

    With ``lam == 0`` the update reduces to plain SVGD. The RBF kernel
    bandwidth is chosen once by a median heuristic, unless given, and is
    reused for the MMD diagnostic.
    """

    def __init__(self, lam= 1, kernel_bandwidth=None):
        # Bandwidth h of k(x, y) = exp(-||x - y||^2 / (2 h)); None means it is
        # set by the median heuristic on the first call to rbf_kernel.
        self.kernel_bandwidth = kernel_bandwidth
        # One MMD value appended per iteration of update().
        self.mmd_history = []
        # Weight of the Langevin (drift + noise) component.
        self.lam = lam

    def rbf_kernel(self, X, Y=None):
        """
        Compute the RBF kernel matrix between the rows of X and Y.

        :param X: Array of shape (N, D)
        :param Y: Array of shape (M, D); defaults to X
        :return: K of shape (N, M), K[i, j] = exp(-||X_i - Y_j||^2 / (2 h))

        If ``self.kernel_bandwidth`` is None it is set here, and cached, to
        median(squared pairwise distances) / log(N + 1).
        """
        if Y is None:
            Y = X
        if X is Y:
            pairwise_dists = squareform(pdist(X, metric='euclidean')) ** 2
        else:
            # The expanded form can dip slightly below zero through cancellation.
            pairwise_dists = np.maximum(
                np.sum(X**2, axis=1).reshape(-1, 1) + np.sum(Y**2, axis=1) - 2 * np.dot(X, Y.T),
                0.0,
            )
        if self.kernel_bandwidth is None:
            h = np.median(pairwise_dists) / np.log(X.shape[0] + 1)
            # A degenerate (zero/NaN) heuristic would make the kernel NaN.
            self.kernel_bandwidth = h if np.isfinite(h) and h > 0 else 1.0
        return np.exp(-pairwise_dists / (2 * self.kernel_bandwidth))

    def compute_mmd(self, particles, target_samples, k_yy_mean=None):
        """
        Compute the biased (V-statistic) estimate of the squared Maximum Mean
        Discrepancy between particles and target samples.

        :param k_yy_mean: Optional precomputed mean of K(target, target),
            which is constant once the bandwidth is fixed.
        :return: mean(K_XX) + mean(K_YY) - 2 * mean(K_XY)
        """
        K_XX = self.rbf_kernel(particles, particles)
        if k_yy_mean is None:
            k_yy_mean = np.mean(self.rbf_kernel(target_samples, target_samples))
        K_XY = self.rbf_kernel(particles, target_samples)
        return np.mean(K_XX) + k_yy_mean - 2 * np.mean(K_XY)

    def update(self, X, grad_log_prob, target_samples, lr=1e-3, iterations=100):
        """
        Update particles using noisy SVGD:
        X <- X + lr * phi(X) + lam * lr * grad log p(X) + sqrt(2 lr lam) * xi,
        where phi is the SVGD direction and xi ~ N(0, I).

        :param X: Initial particles (N, D); updated in place and returned
        :param grad_log_prob: Function to compute the gradient of the log density
        :param target_samples: Samples from the target distribution for MMD calculation
        :param lr: Learning rate
        :param iterations: Number of iterations
        :return: The updated particle array (the same object as X)
        """
        n = X.shape[0]
        k_yy_mean = None
        for _ in range(iterations):
            grad_log_p = grad_log_prob(X)
            # Evaluate the kernel once and reuse it for both SVGD terms.
            K = self.rbf_kernel(X)
            # Repulsive term: sum_j grad_{x_j} k(x_j, x_i).
            grad_K = (X * np.sum(K, axis=1, keepdims=True) - K @ X) / self.kernel_bandwidth
            svgd_grad = (K @ grad_log_p + grad_K) / n
            noise = np.random.normal(0, 1, size=X.shape)
            X += lr * svgd_grad + self.lam * lr * grad_log_p + np.sqrt(2 * lr * self.lam) * noise
            if k_yy_mean is None:
                # The bandwidth is fixed by now, so this term is loop-invariant.
                k_yy_mean = np.mean(self.rbf_kernel(target_samples, target_samples))
            self.mmd_history.append(self.compute_mmd(X, target_samples, k_yy_mean))
        return X
class LangevinDynamics:
    """Unadjusted Langevin dynamics that records an MMD diagnostic after every step."""

    def __init__(self, step_size=1e-3):
        # Step size eps in x <- x + eps * grad log p(x) + sqrt(2 eps) * xi.
        self.step_size = step_size
        # One MMD value appended per iteration of update().
        self.mmd_history = []

    def compute_mmd(self, particles, target_samples, rbf_kernel):
        """Return mean(K_XX) + mean(K_YY) - 2 * mean(K_XY) under the supplied kernel callable."""
        pairs = (
            (particles, particles),
            (target_samples, target_samples),
            (particles, target_samples),
        )
        mean_xx, mean_yy, mean_xy = [np.mean(rbf_kernel(a, b)) for a, b in pairs]
        return mean_xx + mean_yy - 2 * mean_xy

    def update(self, X, grad_log_prob, target_samples, rbf_kernel, iterations=100):
        """Run ``iterations`` Langevin steps on X in place, appending one MMD value per step, and return X."""
        for _step in range(iterations):
            xi = np.random.normal(0, 1, size=X.shape)
            drift = self.step_size * grad_log_prob(X)
            X += drift + np.sqrt(2*self.step_size) * xi
            self.mmd_history.append(self.compute_mmd(X, target_samples, rbf_kernel))
        return X

# Example usage
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Variance of the first coordinate x0 of the target distribution.
    var = 3

    def sample_mixture_distribution(n_samples):
        """Draw exact samples from the target: x0 ~ N(0, var), x1 | x0 ~ N(0, exp(x0))."""
        samples = np.zeros((n_samples, 2))
        samples[:, 0] = np.random.randn(n_samples) * np.sqrt(var)
        # A standard deviation of exp(x0 / 2) gives conditional variance exp(x0).
        samples[:, 1] = np.random.randn(n_samples) * np.exp(samples[:, 0] / 2)
        return samples

    def prob(x):
        """Target density at a single point x = (x0, x1).

        p(x) = N(x0; 0, var) * N(x1; 0, exp(x0)). The exp(-x0 / 2) factor of the
        conditional normalizer is already included in ``log_p`` via the
        ``- x[0] / 2`` term, so it must not be divided out a second time.
        """
        log_p = -x[0] ** 2 / (2 * var) - 0.5 * (x[1] ** 2) * np.exp(-x[0]) - x[0] / 2
        return np.exp(log_p) / (2 * np.pi * np.sqrt(var))

    def grad_log_prob(x):
        """Gradient of the target log-density at every row of x (shape (N, 2))."""
        x0, x1 = x[:, 0], x[:, 1]
        gradients = np.zeros_like(x)
        gradients[:, 0] = -x0 / var + 0.5 * (x1 ** 2) * np.exp(-x0) - 1 / 2
        gradients[:, 1] = -x1 * np.exp(-x0)
        return gradients

    def log_prob(X):
        """Unnormalized log-density of a correlated 2-D Gaussian (not used by the experiments below)."""
        mean = np.array([0, 0])
        cov = np.array([[1, 0.8], [0.8, 1]])
        inv_cov = np.linalg.inv(cov)
        diff = X - mean
        return -0.5 * np.sum(diff @ inv_cov * diff, axis=1)

    # Call np.random.seed(...) here for reproducible runs.

    def MMD():
        """Run SVGD, NSVGD (three λ values) and Langevin dynamics from shared
        initial particles over several runs, then plot each method's mean MMD
        curve with a normal-approximation 95% confidence band."""
        svgd_mmd_all = []
        nsvgd1_mmd_all = []
        nsvgd2_mmd_all = []
        nsvgd3_mmd_all = []
        langevin_mmd_all = []
        iterations = 1000
        runs = 50

        for _ in range(runs):
            target_samples = sample_mixture_distribution(500)
            particles_svgd = np.random.randn(100, 2)
            # Every method starts from the same initial particles.
            particles_nsvgd1 = np.copy(particles_svgd)
            particles_nsvgd2 = np.copy(particles_svgd)
            particles_nsvgd3 = np.copy(particles_svgd)
            particles_langevin = np.copy(particles_svgd)

            svgd = SVGD()
            svgd.update(particles_svgd, grad_log_prob, target_samples, lr=0.1, iterations=iterations)
            svgd_mmd_all.append(svgd.mmd_history)

            nsvgd1 = NSVGD(1)
            nsvgd1.update(particles_nsvgd1, grad_log_prob, target_samples, lr=0.1, iterations=iterations)
            nsvgd1_mmd_all.append(nsvgd1.mmd_history)

            nsvgd2 = NSVGD(0.1)
            nsvgd2.update(particles_nsvgd2, grad_log_prob, target_samples, lr=0.1, iterations=iterations)
            nsvgd2_mmd_all.append(nsvgd2.mmd_history)

            nsvgd3 = NSVGD(0.01)
            nsvgd3.update(particles_nsvgd3, grad_log_prob, target_samples, lr=0.1, iterations=iterations)
            nsvgd3_mmd_all.append(nsvgd3.mmd_history)

            # Langevin reuses the SVGD kernel (and its bandwidth) for a comparable MMD.
            langevin = LangevinDynamics(step_size=0.1)
            langevin.update(particles_langevin, grad_log_prob, target_samples, svgd.rbf_kernel, iterations=iterations)
            langevin_mmd_all.append(langevin.mmd_history)

        # (legend label, marker, per-run histories of shape (runs, iterations))
        curves = [
            ("SVGD (NSVGD with λ=0)", 'o', svgd_mmd_all),
            ("NSVGD with λ=1", 's', nsvgd1_mmd_all),
            ("NSVGD with λ=0.1", '^', nsvgd2_mmd_all),
            ("NSVGD with λ=0.01", 'v', nsvgd3_mmd_all),
            ("Langevin Dynamics", 'D', langevin_mmd_all),
        ]
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

        plt.figure()
        # Ten evenly spaced iterations get a marker so the curves can be told apart.
        indices = np.linspace(0, iterations - 1, 10, dtype=int)
        for (label, marker, histories), color in zip(curves, colors):
            mean = np.mean(histories, axis=0)
            # Half-width of a 95% confidence interval for the mean over runs.
            ci = np.std(histories, axis=0) * 1.96 / np.sqrt(runs)
            plt.plot(mean, color=color)
            plt.fill_between(range(len(mean)), mean - ci, mean + ci, alpha=0.2, color=color)
            plt.scatter(indices, mean[indices], color=color, marker=marker, label=label)

        plt.title(f"Mean (over {runs} runs) MMD Evolution Over Iterations with Variance")
        plt.grid()
        plt.xlabel("Iteration")
        plt.ylabel("MMD")
        plt.legend()

        plt.savefig("figure_numpy_with_variance.pdf")
        plt.savefig("comparaison_with_variance.png")
        plt.show()

    def plot():
        """Scatter final SVGD and Langevin particles against exact target samples and save the figure."""
        target_samples = sample_mixture_distribution(100)
        particles_svgd = np.random.randn(500, 2)
        particles_lan = np.copy(particles_svgd)

        svgd = SVGD()
        svgd_samples = svgd.update(particles_svgd, grad_log_prob, target_samples, lr=0.1, iterations=500)

        lan = LangevinDynamics(0.1)
        langevin_samples = lan.update(particles_lan, grad_log_prob, target_samples, svgd.rbf_kernel, 500)

        plt.figure(figsize=(10, 8))
        plt.scatter(target_samples[:, 0], target_samples[:, 1], label="Target Samples", alpha=0.3, s=10)
        plt.scatter(svgd_samples[:, 0], svgd_samples[:, 1], label="SVGD Samples", alpha=0.8, s=10)
        plt.scatter(langevin_samples[:, 0], langevin_samples[:, 1], label="Langevin Samples", alpha=0.8, s=10)

        plt.title("Samples of SVGD and Langevin with respect to the target")
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.legend()
        plt.grid(alpha=0.3)

        plt.savefig("samples_comparison.pdf")
        plt.savefig("samples_comparison.png")

    MMD()