import numpy as np
import matplotlib.pyplot as plt
import torch

# ----- Matplotlib Global Config -----
# Font sizes, in points.
plt.rcParams.update({
    "font.size": 9,
    "axes.titlesize": 8,
    "axes.labelsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "legend.fontsize": 4,
})

# Resolution for both on-screen figures and saved files.
plt.rcParams.update({
    "figure.dpi": 600,
    "savefig.dpi": 600,
})

# Stroke widths for the axes frame and the major tick marks.
plt.rcParams.update({
    "axes.linewidth": 0.8,
    "xtick.major.width": 0.8,
    "ytick.major.width": 0.8,
})

# Font type 42 makes the PDF/PS backends embed TrueType fonts.
plt.rcParams.update({
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

# Use the constrained-layout engine for every new figure.
plt.rcParams.update({
    "figure.constrained_layout.use": True,
})

# ----- RBF Kernel -----
def rbf_kernel(x1, x2, lengthscale=0.2, variance=1.0):
    """Evaluate the squared-exponential (RBF) kernel between two point sets.

    Computes ``variance * exp(-0.5 * ||a - b||^2 / lengthscale^2)`` for every
    pair ``(a, b)`` where ``a`` is a row of ``x1`` and ``b`` is a row of ``x2``.

    Args:
        x1: Tensor whose first axis indexes points; the last axis is summed
            over when forming squared distances.
        x2: Tensor laid out like ``x1``.
        lengthscale: Kernel length scale.
        variance: Multiplicative output scale.

    Returns:
        Tensor of pairwise kernel values; shape ``(n, m)`` for inputs of
        shape ``(n, d)`` and ``(m, d)``.
    """
    # Broadcast (n, 1, ...) against (1, m, ...) to get all pairwise offsets.
    pairwise_delta = x1[:, None] - x2[None]
    squared_distance = pairwise_delta.pow(2).sum(dim=-1)
    exponent = -0.5 * squared_distance / lengthscale ** 2
    return variance * torch.exp(exponent)

def generate_rbf_data(num_datasets, num_points, hyperparameters):
    """Draw sample functions from a Gaussian with an RBF covariance on [0, 1].

    Args:
        num_datasets: Number of independent sample functions to draw.
        num_points: Number of evenly spaced inputs between 0 and 1.
        hyperparameters: Mapping with optional keys ``'lengthscale'``
            (default 0.2), ``'kernel_variance'`` (default 0.01),
            ``'output_noise'`` (default 0.01; it is squared before being
            added to the covariance diagonal) and ``'mean_shift'``
            (default 1.05; the constant mean of every sample).

    Returns:
        Tuple ``(xs, ys)``: ``xs`` has shape ``(num_datasets, num_points, 1)``
        and holds the same grid for every dataset; ``ys`` has shape
        ``(num_datasets, num_points)``.
    """
    lengthscale = hyperparameters.get('lengthscale', 0.2)
    kernel_variance = hyperparameters.get('kernel_variance', 0.01)
    output_noise = hyperparameters.get('output_noise', 0.01)
    mean_shift = hyperparameters.get('mean_shift', 1.05)
    # Same diagonal stabiliser generate_linear_periodic_data uses; without it
    # the Cholesky factorisation inside MultivariateNormal can fail when
    # output_noise is zero or very small.
    jitter = 1e-6

    grid = torch.linspace(0, 1, steps=num_points).unsqueeze(-1)
    xs = grid.repeat(num_datasets, 1, 1)

    # Every dataset shares the same grid, so the covariance (and the
    # factorisation done when the distribution is constructed) is built once.
    covariance = rbf_kernel(grid, grid, lengthscale, kernel_variance)
    covariance = covariance + (output_noise**2 + jitter) * torch.eye(num_points)
    prior = torch.distributions.MultivariateNormal(
        torch.full((num_points,), mean_shift), covariance
    )

    # One independent draw per dataset.
    ys = torch.stack([prior.sample() for _ in range(num_datasets)], dim=0)
    return xs, ys

# ----- Linear-Periodic Kernel -----
def linear_periodic_kernel(x1, x2, lengthscale=0.2, variance=1.0, period=1.0):
    """Evaluate the product of a linear kernel and a periodic kernel.

    For rows ``a`` of ``x1`` and ``b`` of ``x2`` the value is
    ``variance * (a . b) * exp(-2 * sin^2(pi * |a - b| / period) / lengthscale^2)``.

    Args:
        x1: 2-D tensor of shape ``(n, d)``.
        x2: 2-D tensor of shape ``(m, d)``.
        lengthscale: Length scale of the periodic factor.
        variance: Multiplicative output scale.
        period: Period of the periodic factor.

    Returns:
        Tensor of shape ``(n, m)`` with the pairwise kernel values.
    """
    separation_sq = (x1[:, None] - x2[None]).pow(2).sum(dim=-1)
    # The squared distance is floored at 1e-12 before the square root, so
    # coincident points use a distance of 1e-6 instead of exactly 0.
    separation = separation_sq.clamp(min=1e-12).sqrt()
    sin_squared = torch.sin(np.pi * separation / period).pow(2)
    periodic_factor = torch.exp(-2 * sin_squared / lengthscale ** 2)
    linear_factor = x1 @ x2.T
    return variance * linear_factor * periodic_factor

def generate_linear_periodic_data(num_datasets, num_points, hyperparameters):
    """Draw sample functions from a Gaussian with a linear-periodic covariance.

    Args:
        num_datasets: Number of independent sample functions to draw.
        num_points: Number of evenly spaced inputs between 0 and 1.
        hyperparameters: Mapping with optional keys ``'lengthscale'``
            (default 0.2), ``'kernel_variance'`` (default 0.01),
            ``'output_noise'`` (default 0.01; it is squared before being
            added to the covariance diagonal), ``'mean_shift'``
            (default 1.05; the constant mean of every sample) and
            ``'period'`` (default 1.0).

    Returns:
        Tuple ``(xs, ys)``: ``xs`` has shape ``(num_datasets, num_points, 1)``
        and holds the same grid for every dataset; ``ys`` has shape
        ``(num_datasets, num_points)``.
    """
    lengthscale = hyperparameters.get('lengthscale', 0.2)
    kernel_variance = hyperparameters.get('kernel_variance', 0.01)
    output_noise = hyperparameters.get('output_noise', 0.01)
    mean_shift = hyperparameters.get('mean_shift', 1.05)
    period = hyperparameters.get('period', 1.0)
    # Small diagonal term added on top of the noise variance. The linear
    # factor makes the kernel exactly 0 at x = 0, so the diagonal needs it.
    jitter = 1e-6

    grid = torch.linspace(0, 1, steps=num_points).unsqueeze(-1)
    xs = grid.unsqueeze(0).repeat(num_datasets, 1, 1)

    # Every dataset shares the same grid, so the covariance (and the
    # factorisation done when the distribution is constructed) is built once
    # instead of once per dataset.
    covariance = linear_periodic_kernel(
        grid, grid, lengthscale, kernel_variance, period
    )
    covariance = covariance + (output_noise**2 + jitter) * torch.eye(num_points)
    prior = torch.distributions.MultivariateNormal(
        torch.full((num_points,), mean_shift), covariance
    )

    # One independent draw per dataset, in the same order as before.
    ys = torch.stack([prior.sample() for _ in range(num_datasets)], dim=0)
    return xs, ys
# ----- Parameters -----
num_datasets = 2    # sample functions drawn per kernel
num_points = 100    # grid resolution on [0, 1]
hyperparameters = {
    'lengthscale': 0.6,
    'kernel_variance': 0.01,
    'output_noise': 1e-3,   # squared by the generators before use
    'mean_shift': 1.05,     # constant mean of every sample
    'period': 1.0,          # only read by generate_linear_periodic_data
}

# Generate data
xs_rbf, ys_rbf = generate_rbf_data(num_datasets, num_points, hyperparameters)
xs_lp, ys_lp = generate_linear_periodic_data(num_datasets, num_points, hyperparameters)

# ----- Plot -----
fig, ax = plt.subplots(figsize=(3.25, 2.2))

colors = ["blue", "red", "black", "darkgreen"]

# Linear-Periodic Kernel (Smooth)
for i, (x, y) in enumerate(zip(xs_lp, ys_lp)):
    ax.plot(x[:, 0].numpy(), y.numpy(),
            label=f'Smooth Dataset {i+1}', linewidth=1,
            color=colors[i % len(colors)])

# RBF Kernel (Non-Smooth)
# Colours continue after the ones used above; the modulo keeps the index
# inside the palette when num_datasets is larger than half its length.
for i, (x, y) in enumerate(zip(xs_rbf, ys_rbf)):
    ax.plot(x[:, 0].numpy(), y.numpy(),
            label=f'Non-Smooth Dataset {i+1}', linestyle='--', linewidth=1,
            color=colors[(i + num_datasets) % len(colors)])

ax.set_title('Sample Functions from PFN Training Datasets')
ax.set_xlabel('Input')
ax.set_ylabel('Output')
ax.legend()

# Save the same figure as raster (PNG) and vector (PDF, SVG) files.
for extension in ("png", "pdf", "svg"):
    fig.savefig(f"synthetic_gp_smooth_vs_nonsmooth.{extension}", dpi=600)

plt.show()
