# -*- coding: utf-8 -*-
import torch
import matplotlib.pyplot as plt
from spline_test import B_batch, coef2curve, curve2coef, extend_grid
import numpy as np

# High-resolution figures (600 dpi) for the plot produced below.
plt.rcParams['figure.dpi'] = 600

# Seed the RNG so the sampled evaluation points -- and with them the fitted
# coefficients, the saved figure and the printed error metrics -- are the
# same on every run. Without this, runs with different (k, G) settings fit
# different random data and their results are not comparable.
SEED = 0
torch.manual_seed(SEED)

batch = 1000   # number of sample points
in_dim = 2     # number of input variables
out_dim = 1    # number of outputs

# Inputs drawn uniformly from [-1, 1)^in_dim; target is
# f(x0, x1) = exp(pi * x0 + x1**2), giving y_eval of shape (batch, 1).
x_eval = torch.rand(batch, in_dim) * 2 - 1
y_eval = torch.exp(torch.pi * x_eval[:, 0:1] + x_eval[:, 1:2] ** 2)

k = 10  # spline order/degree parameter passed as `k` to the spline_test helpers
G = 30  # grid-size parameter (also used in the saved figure's file name)

# One identical row of G + k evenly spaced knots on [-1, 1] per input
# dimension: grid has shape (in_dim, G + k).
# NOTE(review): a grid with G intervals would normally have G + 1 points;
# confirm that G + k matches spline_test's convention.
grid = torch.linspace(-1, 1, steps=G + k)[None, :].expand(in_dim, G + k)
# Presumably pads k extra knots beyond each end of the grid -- confirm in
# spline_test.extend_grid.
extended_grid = extend_grid(grid, k_extend=k)

# Least-squares fit of spline coefficients. y_eval.unsqueeze(1) has shape
# (batch, 1, 1); presumably curve2coef expects (batch, in_dim, out_dim)
# and broadcasts the size-1 middle axis -- verify against spline_test.
coef = curve2coef(x_eval, y_eval.unsqueeze(1), extended_grid, k=k)

print("Shape of fitted B-spline coefficients:", coef.shape)

# Evaluate the fitted splines at the same points. The metrics further down
# index this as y_pred[:, 0, 0], so it is (at least) 3-D.
y_pred = coef2curve(x_eval, extended_grid, coef, k=k)

# Plot the fit along input dimension 0, with samples ordered by x0.
sorted_indices = torch.argsort(x_eval[:, 0])
X_test_sorted = x_eval[sorted_indices]
y_test_sorted = y_eval[sorted_indices]
# Use the same slice the error metrics use: the dimension-0 spline output.
# Flattening y_pred and indexing it with per-sample indices would pick
# interleaved values from every input dimension whenever y_pred holds more
# than one value per sample, so the plotted curve would not match the
# samples it is drawn against.
predictions_sorted = y_pred[:, 0, 0][sorted_indices]

# NOTE: the target also depends on x1, so the ground-truth trace is not a
# function of x0 alone and will look jagged when drawn against x0.
plt.figure(figsize=(10, 6))
plt.plot(X_test_sorted[:, 0], y_test_sorted, label='Ground Truth', color='dodgerblue', linewidth=2)
plt.plot(X_test_sorted[:, 0], predictions_sorted, label='Spline Approximation', color='coral', linestyle='--', linewidth=2)
plt.legend()
plt.title('B-Spline Approximation of a Nonlinear Function')
plt.xlabel('Input Dimension 1')
plt.ylabel('Function Output')
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
# File name encodes the spline settings so runs with different (k, G) do
# not overwrite each other.
plt.savefig('spline_k{}_G{}.png'.format(k, G))
plt.show()

# Error of the dimension-0 spline output against the sampled targets,
# computed once from a shared per-sample residual.
residual = y_pred[:, 0, 0] - y_eval[:, 0]
mse = residual.pow(2).mean()
rmse = mse.sqrt()
mae = residual.abs().mean()

print(f"Mean Squared Error (MSE): {mse:.4e}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4e}")
print(f"Mean Absolute Error (MAE): {mae:.4e}")

