import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.ticker import AutoMinorLocator, FormatStrFormatter

# Paper-style plotting defaults applied to every figure this script produces.
plt.style.use('seaborn-v0_8-paper')
plt.rc('font', family='serif', serif=['DejaVu Serif', 'Times New Roman', 'Liberation Serif'], size=12)
plt.rc('axes', labelsize=12, titlesize=14)
plt.rc(('xtick', 'ytick'), labelsize=10)
plt.rc('legend', fontsize=10)
# High-resolution output; a 'tight' bounding box trims whitespace to a 0.1-inch pad.
plt.rc('figure', dpi=300)
plt.rc('savefig', dpi=300, bbox='tight', pad_inches=0.1)

def compute_weighted_sum(z, weights, norm, *, eps=1e-30):
    """Normalize ``weights`` by ``z``, falling back to a one-hot row where ``z`` is ~0.

    Args:
        z: Normalizer, broadcast against ``weights``. Presumably the per-row
            sum of unnormalized weights with shape (m, 1) -- TODO confirm.
        weights: Tensor to normalize; must have the same number of dims as
            ``norm`` (required by ``scatter_``), presumably (m, n).
        norm: Tensor whose per-row argmin (over dim 1) selects the fallback
            column for rows where ``z <= eps``.
        eps: Threshold at or below which ``z`` is treated as zero. The same
            value is used for the clamp so every row taking the weighted
            branch is divided by its true ``z``.

    Returns:
        Tensor equal to ``weights / z`` where ``z > eps``, and a one-hot row
        (1 at ``argmin(norm, dim=1)``) elsewhere.
    """
    valid = z > eps
    # Clamp with the SAME threshold as the mask: previously 1e-8 vs 1e-30 meant
    # rows with 1e-30 < z <= 1e-8 were divided by 1e-8 instead of z.
    weighted_part = weights / z.clamp(min=eps)
    min_cols = torch.argmin(norm, dim=1, keepdim=True)
    sparse_part = torch.zeros_like(weights)
    sparse_part.scatter_(1, min_cols, 1)
    # Select rather than blend: with mask arithmetic, 0 * inf (or 0 * nan) in the
    # unused branch would leak NaN into the result.
    return torch.where(valid, weighted_part, sparse_part)


def f(t, X, Y):
    """Oracle velocity and score for the interpolant anchored on training set ``Y``.

    Schedule: ``alpha(t) = t`` multiplies the data point, ``beta(t) = 1 - t``
    and ``gamma(t) = t * (1 - t)`` are the noise coefficients. The posterior
    weights treat ``x_t`` given a training point ``y`` as Gaussian with mean
    ``alpha * y`` and variance ``C3 = beta**2 + gamma**2``.

    Args:
        t: Scalar time. Must not be 1, where ``C3`` vanishes.
        X: (m, d) tensor of current states.
        Y: (n, d) tensor of training points.

    Returns:
        velocity: (m, d) posterior-weighted average of the per-point drift
            ``(C1 * x - C2 * y) / C3``.
        score: (m, d) tensor, ``(alpha * velocity - alpha_dot * X) / B_t``,
            which algebraically equals ``-(X - alpha * E[y | X]) / C3``.
        B_t: Scalar ``beta*(beta*alpha_dot - beta_dot*alpha)
            + gamma*(gamma*alpha_dot - gamma_dot*alpha)``.

    Raises:
        ValueError: if ``beta**2 + gamma**2 == 0`` (i.e. ``t == 1``).
    """
    alpha, beta, gamma = t, 1 - t, t * (1 - t)
    alpha_dot, beta_dot, gamma_dot = 1, -1, 1 - 2 * t
    # NOTE: an earlier variant used gamma = sqrt(t*(1-t)) with
    # gamma_dot = (1-2t) / (2*sqrt(t*(1-t))); the polynomial schedule is used here.

    B_t = beta * (beta * alpha_dot - beta_dot * alpha) + gamma * (gamma * alpha_dot - gamma_dot * alpha)

    C1 = gamma * gamma_dot + beta * beta_dot
    C3 = gamma**2 + beta**2
    if C3 == 0:
        raise ValueError("f is undefined when beta(t)**2 + gamma(t)**2 == 0 (t == 1)")
    C2 = C1 * alpha - C3 * alpha_dot

    # Drift conditioned on each training point y_j, for every state x_i.
    diff = (C1 * X.unsqueeze(1) - C2 * Y.unsqueeze(0)) / C3  # (m, n, d)
    # Gaussian log-likelihood (up to a constant) of x_i given y_j; softmax over j
    # turns it into posterior weights.
    norm_sq = torch.sum((X.unsqueeze(1) - alpha * Y.unsqueeze(0)) ** 2, dim=2) / (2 * C3)  # (m, n)
    weights = F.softmax(-norm_sq, dim=1)

    velocity = torch.sum(diff * weights.unsqueeze(-1), dim=1)  # (m, d)
    score = alpha / B_t * velocity - alpha_dot / B_t * X
    return velocity, score, B_t

def set_seed(seed=42):
    """Seed torch's and NumPy's global RNGs with ``seed`` (torch first)."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)

def generate_data(num_samples):
    """Return a (num_samples, 2) tensor of points uniform on [-1, 1) x [-1, 1)."""
    unit = torch.rand(num_samples, 2)  # uniform on [0, 1)
    return unit.mul(2).sub(1)

def sample(data, num_generate, *, eta=0.1, timesteps=100):
    """Generate 2-D points by integrating the oracle SDE from Gaussian noise.

    Starts from ``x ~ N(0, I)`` and applies Euler-Maruyama steps of size
    ``dt = 1 / timesteps`` to
        dX = (velocity + eps(t) * score) dt + sqrt(2 * eps(t)) dW,
    where ``eps(t) = eta * sqrt(t * (1 - t))`` and ``velocity``/``score``
    come from ``f``.

    Args:
        data: Training points passed to ``f``; must be (n, 2) to match the
            2-D initial noise.
        num_generate: Number of samples to draw.
        eta: Noise-intensity multiplier. With 0 every update is deterministic
            (noise is still drawn, so RNG consumption is unchanged).
        timesteps: Grid resolution. The loop evaluates
            t = 1/timesteps, ..., (timesteps - 1)/timesteps and so never reaches
            t = 1, where ``f`` is undefined.

    Returns:
        Tuple ``(samples, eta)``: a (num_generate, 2) numpy array and the eta used.
    """
    x_t = torch.randn(num_generate, 2)
    for step in range(1, timesteps):
        t_in = step / timesteps
        velocity, score, _ = f(t_in, x_t, data)
        epsilon_t = math.sqrt(t_in * (1 - t_in)) * eta
        noise = torch.randn_like(x_t)
        # The Brownian increment over dt has standard deviation sqrt(dt), so the
        # diffusion term is sqrt(2 * eps * dt) * N(0, I); sqrt(2 * eps) alone would
        # inject sqrt(timesteps) times too much noise per step.
        x_t = x_t + (velocity + epsilon_t * score) / timesteps + math.sqrt(2 * epsilon_t / timesteps) * noise
    return x_t.detach().numpy(), eta

# Experiment: draw a small uniform training set, generate samples with the
# oracle sampler, and save a scatter plot of both.
num_samples = 5
num_generate = 2000
seed = 42
set_seed(seed)
data = generate_data(num_samples)  # training set

generated_samples, eta = sample(data, num_generate)

plt.scatter(generated_samples[:, 0], generated_samples[:, 1], color='#d62728', label='Generated Samples', alpha=0.5)
plt.scatter(data[:, 0], data[:, 1], color='#1f77b4', label='Training Data', alpha=1)
plt.legend()

# The output directory may not exist yet (e.g. a fresh checkout); create it so
# savefig does not raise FileNotFoundError.
out_dir = Path('simulation')
out_dir.mkdir(parents=True, exist_ok=True)
name = out_dir / f'oracle,eta={eta}.jpg'

plt.savefig(name)