import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os

# Global Matplotlib look-and-feel, applied once at import time.
plt.style.use('seaborn-v0_8-paper')

# Typeface: serif family with an ordered list of fallback fonts.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['DejaVu Serif', 'Times New Roman', 'Liberation Serif'],
})

# Text sizes (in points) for body text, axis labels, titles, ticks and legend.
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
})

# Resolution for on-screen and saved figures, plus tight cropping on save.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
})

def compute_weighted_sum(z, weights, norm):
    """Normalise ``weights`` by ``z``, falling back to a one-hot row where ``z`` vanishes.

    Entries where ``z > 1e-30`` yield ``weights / z`` (``z`` is clamped to at
    least ``1e-8`` before dividing, so values of ``z`` between those two
    thresholds are divided by ``1e-8`` rather than by ``z`` itself).
    Everywhere else the result is taken from a one-hot tensor that has a 1 in
    the column where ``norm`` is smallest along dim 1.

    Args:
        z: Tensor broadcastable against ``weights``; presumably the
            normaliser of ``weights`` (assumption - confirm with the caller).
        weights: Tensor with at least two dimensions; the one-hot fallback is
            scattered along its dim 1.
        norm: Tensor whose argmin along dim 1 selects the fallback column.

    Returns:
        Tensor with the broadcast shape of ``z`` and ``weights``.
    """
    # 1.0 where the normaliser is usable, 0.0 where the fallback applies.
    is_valid = (z > 1e-30).float()

    # One-hot fallback: a single 1 per row, at the column of minimal ``norm``.
    nearest_col = torch.argmin(norm, dim=1, keepdim=True)
    one_hot = torch.zeros_like(weights).scatter_(1, nearest_col, 1)

    normalised = weights / z.clamp(min=1e-8)

    # Arithmetic blend (not a select), matching the two-term mask formula.
    return is_valid * normalised + (1 - is_valid) * one_hot

def f(t, X, Y):
    """Evaluate the closed-form velocity and score at time ``t``.

    The coefficients are ``alpha = t``, ``beta = 1 - t`` and
    ``gamma = sqrt(t * (1 - t))``.  Each row of ``X`` is paired with every row
    of ``Y``; the per-pair terms are averaged with softmax weights computed
    from the scaled squared distance between ``X`` and ``alpha * Y``.

    Args:
        t: Scalar time.  Must lie strictly inside (0, 1): ``gamma_dot``
            divides by ``sqrt(t * (1 - t))``.
        X: 2-D tensor of query points, one point per row.
        Y: 2-D tensor of reference points with the same number of columns
            as ``X``.

    Returns:
        Tuple ``(velocity, score)``, each with the same shape as ``X``.
    """
    # Time-dependent coefficients and their derivatives with respect to t.
    alpha, beta = t, 1 - t
    gamma = math.sqrt(t * (1 - t))
    alpha_dot, beta_dot = 1, -1
    gamma_dot = (1 - 2 * t) / gamma / 2

    # Scalar that converts the velocity into the score below.
    A_t = beta * (beta * alpha_dot - beta_dot * alpha) + gamma * (gamma * alpha_dot - gamma_dot * alpha)

    # Coefficients of the per-pair term; C2 is built from C1 and C3.
    C1 = beta ** 3 * gamma * gamma_dot - beta_dot
    C3 = beta ** 3 * gamma ** 2 + beta
    C2 = C1 * alpha - C3 * alpha_dot

    # Broadcast to (rows of X, rows of Y, columns).
    queries = X.unsqueeze(1)
    references = Y.unsqueeze(0)

    pair_term = -(C1 * queries - C2 * references) / C3
    scaled_sq_dist = ((queries - alpha * references) ** 2).sum(dim=2) / (2 * C3 * beta)

    # Softmax over the rows of Y: closer references receive larger weight.
    pair_weights = F.softmax(-scaled_sq_dist, dim=1)

    velocity = (pair_term * pair_weights.unsqueeze(-1)).sum(dim=1)
    score = alpha / A_t * velocity - alpha_dot / A_t * X

    return velocity, score

def set_seed(seed=42):
    """Seed the NumPy and PyTorch global random number generators.

    Args:
        seed: Integer seed handed to both libraries (default 42).
    """
    # The two generators are independent, so the seeding order is irrelevant.
    np.random.seed(seed)
    torch.manual_seed(seed)

def generate_data(num_samples):
    """Draw ``num_samples`` 2-D points uniformly from the square [-1, 1) x [-1, 1).

    Args:
        num_samples: Number of points (rows) to draw.

    Returns:
        Tensor of shape ``(num_samples, 2)``.
    """
    # torch.rand samples from [0, 1); stretch by 2 and shift by -1.
    unit_uniform = torch.rand(num_samples, 2)
    return unit_uniform.mul(2).sub(1)

def sample(data, num_generate, eta, timesteps=100):
    """Generate samples by stepping the velocity/score field of ``f`` forward in time.

    Starting from standard-normal noise, the state is updated once for each
    ``t = k / timesteps`` with ``k = 1 .. timesteps - 1``.  The endpoints
    ``t = 0`` and ``t = 1`` are never evaluated, because ``f`` divides by
    ``sqrt(t * (1 - t))``.

    Args:
        data: 2-D tensor of reference points, one per row; passed to ``f``
            as ``Y``.
        num_generate: Number of samples to generate.
        eta: Noise-level multiplier; the per-step level is
            ``sqrt(t * (1 - t)) * eta``.
        timesteps: Number of subdivisions of the unit time interval
            (default 100).  Must be at least 2 so that at least one update
            is performed.

    Returns:
        Tuple ``(samples, eta)``: ``samples`` is a NumPy array of shape
        ``(num_generate, data.shape[1])`` and ``eta`` is returned unchanged.

    Raises:
        ValueError: If ``timesteps`` is smaller than 2.
    """
    if timesteps < 2:
        raise ValueError(f"timesteps must be at least 2, got {timesteps}")

    # Match the dimensionality of the reference points instead of assuming 2-D.
    x_t = torch.randn(num_generate, data.shape[1])

    for step in range(1, timesteps):
        t_in = step / timesteps

        velocity, score = f(t_in, x_t, data)
        epsilon_t = math.sqrt(t_in * (1 - t_in)) * eta
        noise = torch.randn_like(x_t)

        # NOTE(review): the drift is scaled by the step size (1 / timesteps)
        # but the noise term is not; a standard Euler-Maruyama step would use
        # sqrt(2 * epsilon_t / timesteps).  Left unchanged - confirm intent.
        x_t = x_t + (velocity + epsilon_t * score) / timesteps + math.sqrt(2 * epsilon_t) * noise

    return x_t.detach().numpy(), eta

# Main execution: draw a few training points, run the sampler, and compare the
# generated samples with Gaussian clouds centred on the training points.
seed = 42
num_samples = 5
num_generate = 2000
set_seed(seed)

data = generate_data(num_samples)
data_np = data.numpy()

eta = 0.1
generated_samples, eta = sample(data, num_generate, eta)

# Reference variance: the average of 2 * sqrt(t * (1 - t)) * eta over a
# uniform grid of t in [0, 1].
t_values = torch.linspace(0, 1, 1000)
variance = (2 * torch.sqrt(t_values * (1 - t_values)) * eta).mean().item()

# Create figure
fig, ax = plt.subplots(figsize=(8, 6))

# Isotropic Gaussian cloud with the reference variance around each training point.
num_samples_per_point = 200
sigma = np.sqrt(variance)
generated_normal_points = []
for point in data_np:
    samples = point + sigma * np.random.randn(num_samples_per_point, 2)
    generated_normal_points.append(samples)
generated_normal_points = np.vstack(generated_normal_points)

# Plot all points; the draw order (generated, training, normal) is kept.
ax.scatter(
    generated_samples[:, 0], generated_samples[:, 1],
    color='#d62728', label='Generated Samples', alpha=0.5, s=10,
)
ax.scatter(
    data_np[:, 0], data_np[:, 1],
    color='#1f77b4', label='Training Data', alpha=1, s=50, edgecolor='k',
)
ax.scatter(
    generated_normal_points[:, 0], generated_normal_points[:, 1],
    color='#2ca02c', label=f'Normal Samples (σ²={variance:.3f})', alpha=0.3, s=10,
)

ax.legend()

# Save figure (the output directory is created on demand).
os.makedirs('simulation', exist_ok=True)
plt.savefig(f'simulation/oracle_eta={eta}.jpg')
plt.close()