import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import multivariate_normal

# Publication-style plotting defaults: start from the seaborn "paper" theme,
# then override fonts, tick/legend sizes and the saved-figure resolution.
plt.style.use('seaborn-v0_8-paper')

plt.rc('font',
       family='serif',
       serif=['DejaVu Serif', 'Times New Roman', 'Liberation Serif'],
       size=12)
plt.rc('axes', labelsize=12, titlesize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)
plt.rc('legend', fontsize=10)
plt.rc('figure', dpi=300)
plt.rc('savefig', dpi=300, bbox='tight', pad_inches=0.1)


def compute_mean_pdf(x, M, Sigma):
    """Return the mean of the Gaussian densities N(x_i; M_i, Sigma).

    Row ``i`` of ``x`` is evaluated against row ``i`` of ``M`` (NumPy
    broadcasting applies, so ``M`` may also be a single mean of shape (p,)).
    Rows are paired, not all-pairs, so this is not a mixture density over
    the rows of ``M``. All rows share the covariance ``Sigma``.

    Args:
        x: points, array-like of shape (n, p) or a single point of shape (p,).
        M: means, array-like broadcastable against ``x``, e.g. (n, p) or (p,).
        Sigma: shared covariance matrix, array-like of shape (p, p).

    Returns:
        The average of the n density values (a NumPy scalar).
    """
    # Convert before reading ``.shape`` so that plain lists/tuples are accepted.
    x = np.atleast_2d(np.asarray(x))  # a single point (p,) becomes (1, p)
    M = np.asarray(M)
    Sigma = np.asarray(Sigma)
    p = Sigma.shape[0]

    Sigma_inv = np.linalg.inv(Sigma)
    Sigma_det = np.linalg.det(Sigma)
    diff = x - M  # (n, p)
    # Per-row quadratic form diff_i^T Sigma^{-1} diff_i.
    exponent = -0.5 * np.einsum('ij,jk,ik->i', diff, Sigma_inv, diff)  # (n,)
    normalization = 1.0 / (np.power(2 * np.pi, p / 2) * np.sqrt(Sigma_det))
    pdf_values = normalization * np.exp(exponent)
    return np.mean(pdf_values)

def error_fun3(x_t, t, train_sample, gamma_type):
    """Synthetic error scaled by the inverse of a Gaussian density value.

    The density is ``compute_mean_pdf(x_t, t * train_sample, Sigma)`` with
    ``Sigma = C3 * (1 - t) * I`` and ``C3 = (1 - t)**3 * gamma**2 + (1 - t)``.
    Its reciprocal multiplies noise drawn with ``torch.rand_like`` (U[0, 1)).

    Args:
        x_t: current samples, tensor of shape (n, d).
        t: time in [0, 1].
        train_sample: tensor whose rows, scaled by ``t``, are the Gaussian
            means (paired row-by-row with ``x_t`` by ``compute_mean_pdf``).
        gamma_type: 1, 2 or 3; selects gamma(t).

    Returns:
        Tensor with the same shape as ``x_t``.

    Raises:
        ValueError: if ``gamma_type`` is not 1, 2 or 3.
    """
    alpha = t
    beta = 1 - t
    if gamma_type == 1:
        gamma = t * (1 - t)
    elif gamma_type == 2:
        gamma = math.sqrt(t * (1 - t))
    elif gamma_type == 3:
        gamma = t * (1 - t) ** 2
    else:
        raise ValueError(f"gamma_type must be 1, 2 or 3, got {gamma_type!r}")

    C3 = beta**3 * gamma**2 + beta

    mu_list = alpha * train_sample
    Sigma = C3 * beta * np.eye(x_t.shape[-1])
    # NOTE(review): torch.rand_like draws from U[0, 1) (always positive),
    # whereas error types 1 and 2 use torch.randn_like -- confirm intended.
    noise = torch.rand_like(x_t)
    return 1 / compute_mean_pdf(x_t, mu_list, Sigma) * noise

def compute_weighted_sum(z, weights, norm):
    """Normalise ``weights`` by ``z``, falling back to a one-hot row when z ~ 0.

    Where ``z > 1e-30`` the result is ``weights / max(z, 1e-8)``. Elsewhere it
    is a one-hot row marking the column with the smallest ``norm`` in that row.

    Args:
        z: tensor broadcastable against ``weights`` (presumably (m, 1) or
            (m, n) -- TODO confirm against callers).
        weights: 2-D tensor (m, n).
        norm: 2-D tensor (m, n) used only to pick the fallback column.

    Returns:
        Tensor of shape (m, n).
    """
    keep = (z > 1e-30).float()

    # One-hot fallback: 1 at argmin(norm) of each row, 0 elsewhere.
    nearest_col = norm.argmin(dim=1, keepdim=True)
    one_hot = torch.zeros_like(weights)
    one_hot.scatter_(1, nearest_col, 1)

    # NOTE(review): the mask threshold (1e-30) and the clamp floor (1e-8)
    # differ, so z in (1e-30, 1e-8) is divided by 1e-8, not by z.
    scaled = weights / torch.clamp(z, min=1e-8)
    return keep * scaled + (1 - keep) * one_hot

def f(t, X, Y, gamma_type):
    """Return the exact velocity and score for a point-mass data set ``Y``.

    The weights below model x_t | y ~ N(alpha * y, (beta**2 + gamma**2) I)
    with alpha = t multiplying the data, beta = 1 - t and gamma = gamma(t).
    Here y is one of the rows of ``Y``.

    Args:
        t: time in (0, 1). gamma_type 2 divides by sqrt(t (1 - t)).
        X: query points, tensor of shape (m, d).
        Y: data points, tensor of shape (n, d).
        gamma_type: 1, 2 or 3; selects gamma(t).

    Returns:
        ``(velocity, score)``, each of shape (m, d). ``velocity`` is
        E[d/dt x_t | x_t = X]; ``score`` is grad_x log p_t(X).

    Raises:
        ValueError: if ``gamma_type`` is not 1, 2 or 3.
    """
    alpha = t
    beta = 1 - t
    alpha_dot = 1
    beta_dot = -1
    if gamma_type == 1:
        gamma = t * (1 - t)
        gamma_dot = 1 - 2 * t
    elif gamma_type == 2:
        gamma = math.sqrt(t * (1 - t))
        gamma_dot = (1 - 2 * t) / 2 / math.sqrt(t * (1 - t))
    elif gamma_type == 3:
        gamma = t * (1 - t) ** 2
        gamma_dot = 3 * t**2 - 4 * t + 1
    else:
        raise ValueError(f"gamma_type must be 1, 2 or 3, got {gamma_type!r}")

    C1 = gamma * gamma_dot + beta * beta_dot
    C2 = C1 * alpha - (gamma**2 + beta**2) * alpha_dot
    C3 = gamma**2 + beta**2  # conditional variance of x_t given y

    # Posterior weights p(y_j | x_i) from the Gaussian likelihoods.
    norm_sq = torch.sum((X.unsqueeze(1) - alpha * Y.unsqueeze(0)) ** 2, dim=2) / (2 * C3)  # (m, n)
    weights = F.softmax(-norm_sq, dim=1)

    diff = (C1 * X.unsqueeze(1) - C2 * Y.unsqueeze(0)) / C3  # (m, n, d)
    velocity = torch.sum(diff * weights.unsqueeze(-1), dim=1)  # (m, d)

    # Score of the Gaussian mixture: -(x - alpha * E[y | x]) / C3.
    # Since alpha multiplies the data here, this equals
    # (alpha * velocity - alpha_dot * X) / (-C2); the (beta*v - beta_dot*x)/A_t
    # form only holds when beta multiplies the data.
    posterior_mean = weights @ Y  # (m, d)
    score = -(X - alpha * posterior_mean) / C3

    return velocity, score

def set_seed(seed=42):
    """Seed the global PyTorch and NumPy random generators with ``seed``."""
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

def generate_data(num_samples):
    """Draw ``num_samples`` training points uniformly from [-1, 1)^2.

    Returns:
        Tensor of shape (num_samples, 2).
    """
    unit_square = torch.rand(num_samples, 2)  # U[0, 1)
    return unit_square.mul(2).sub(1)

def sample(data, num_generate, error_type, gamma_type, eta, timesteps=200):
    """Generate samples by integrating the interpolant dynamics with injected errors.

    Starts from standard Gaussian noise and takes Euler(-Maruyama) steps at
    t = 1/timesteps, ..., (timesteps - 1)/timesteps. Each step uses the
    velocity and score from ``f``, both perturbed by a synthetic error:

        x <- x + (v + eps_t * s) * dt + sqrt(2 * eps_t * dt) * xi

    where dt = 1/timesteps, eps_t = eta * sqrt(t (1 - t)) and xi ~ N(0, I).

    Args:
        data: training points passed to ``f`` (tensor of shape (n, d)).
        num_generate: number of samples to generate.
        error_type: 1 = error with a global scale, 2 = error divided by
            gamma(t), 3 = error produced by ``error_fun3``.
        gamma_type: 1, 2 or 3; selects gamma(t).
        eta: noise intensity; 0 makes generation deterministic.
        timesteps: number of grid intervals on [0, 1] (default 200).

    Returns:
        ``(samples, error_level1, error_level2, eta)``, where ``samples`` is
        a NumPy array of shape (num_generate, d).

    Raises:
        ValueError: if ``error_type`` or ``gamma_type`` is not 1, 2 or 3.
    """
    if error_type not in (1, 2, 3):
        raise ValueError(f"error_type must be 1, 2 or 3, got {error_type!r}")
    if gamma_type not in (1, 2, 3):
        raise ValueError(f"gamma_type must be 1, 2 or 3, got {gamma_type!r}")

    # (velocity error level, score error level) per error type.
    # Experiment notes: types 1 and 3 were run with timesteps = 200,
    # type 2 with timesteps = 20000.
    error_level1, error_level2 = {1: (1, 1), 2: (1, 1), 3: (1e-2, 1)}[error_type]

    train_sample = torch.randn(num_generate, data.shape[1])
    x_t = train_sample
    for step in range(1, timesteps):
        t_in = step / timesteps

        if error_type == 3:
            error1 = error_level1 * error_fun3(x_t, t_in, train_sample, gamma_type)
            error2 = error_level2 * error_fun3(x_t, t_in, train_sample, gamma_type)
        else:
            noise1 = torch.randn_like(x_t)
            noise2 = torch.randn_like(x_t)
            # Smooth, position-dependent error profile shared by both terms.
            profile = torch.sin(x_t * 5) * torch.cos(x_t * 2) + torch.sin(x_t * 3)
            if error_type == 2:
                if gamma_type == 1:
                    gamma = t_in * (1 - t_in)
                elif gamma_type == 2:
                    gamma = math.sqrt(t_in * (1 - t_in))
                else:
                    gamma = t_in * (1 - t_in) ** 2
                profile = profile / gamma
            # Noise is normalised by its global norm over all samples.
            error1 = error_level1 * profile * noise1 / torch.norm(noise1)
            error2 = error_level2 * profile * noise2 / torch.norm(noise2)

        velocity, score = f(t_in, x_t, data, gamma_type)
        velocity = velocity + error1
        score = score + error2

        epsilon_t = math.sqrt(t_in * (1 - t_in)) * eta
        noise = torch.randn_like(x_t)
        # Euler-Maruyama: the Brownian increment over dt = 1/timesteps has
        # standard deviation sqrt(dt), hence the division inside the sqrt.
        x_t = (x_t
               + (velocity + epsilon_t * score) / timesteps
               + math.sqrt(2 * epsilon_t / timesteps) * noise)

    return x_t.detach().numpy(), error_level1, error_level2, eta


num_samples = 5
num_generate = 1000
seed = 33
set_seed(seed)
data = generate_data(num_samples)  # training set
error_type = 3  # 1: global control, 2: divided by gamma, 3: inversely proportional to p(x_t)
gamma_type = 1  # 1: same order, 2: convergent, 3: divergent
eta = 0         # noise intensity during generation; 0 makes generation deterministic

generated_samples, error_level1, error_level2, eta = sample(data, num_generate, error_type, gamma_type, eta)

plt.scatter(generated_samples[:, 0], generated_samples[:, 1], color='#d62728', label='Generated Samples', alpha=0.5)
plt.scatter(data[:, 0], data[:, 1], color='#1f77b4', label='Training Data', alpha=1)
plt.legend(loc='upper right')

# savefig does not create missing directories, so ensure the output folder exists.
Path('simulation').mkdir(parents=True, exist_ok=True)
name = (f'simulation/error_type={error_type},gamma={gamma_type},'
        f'error1={error_level1},error2={error_level2},eta={eta}.jpg')
plt.savefig(name)



