import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # For 3D plotting
from spline_test import B_batch, coef2curve, curve2coef, extend_grid
import numpy as np

# Sample 1000 input points in a 1-D input space and fit a B-spline to a step target.
batch = 1000
in_dim = 1  # Input dimension (x1)
out_dim = 1  # Output dimension (y)

# Fix the RNG so repeated runs sample (and therefore fit) the same data.
torch.manual_seed(0)

# Randomly generate `batch` input points uniformly in [-1, 1)
x_eval = torch.rand(batch, in_dim) * 2 - 1  # Shape: (batch, in_dim)

# Heaviside-like target: 1.0 where x1 >= 0, else 0.0
y_eval = (x_eval[:, 0:1] >= 0).float()  # Shape: (batch, 1)

# Define knot positions for the B-spline basis
k = 3  # Cubic B-spline
G = 200  # Number of grid intervals
# G intervals require G + 1 knots (linspace `steps` counts points, not intervals).
grid = torch.linspace(-1, 1, steps=G + 1)[None, :].expand(in_dim, G + 1)  # Shape: (in_dim, G+1)

# Extend the grid at both boundaries by k points -> G + 1 + 2k knots per input dim,
# which yields G + k basis functions of order k.
extended_grid = extend_grid(grid, k_extend=k)

# Compute B-spline coefficients by least-squares fitting of the input-output data.
# NOTE(review): y_eval.unsqueeze(1) is (batch, 1, 1); assumed to be the
# (batch, in_dim, out_dim) layout curve2coef expects — confirm in spline_test.
coef = curve2coef(x_eval, y_eval.unsqueeze(1), extended_grid, k=k)

# Print the shape of the fitted coefficients
print("Shape of fitted B-spline coefficients:", coef.shape)  # Expected: (in_dim, out_dim, G+k)

# Evaluate the B-spline curve at the input points using the fitted coefficients
y_pred = coef2curve(x_eval, extended_grid, coef, k=k)

# Convert tensors to NumPy arrays for visualization
x_eval_np = x_eval.numpy()
y_eval_np = y_eval.numpy().flatten()
y_pred_np = y_pred.detach().numpy().flatten()

# Input and output are both 1-D, so a 2-D scatter puts the target/prediction
# values on the vertical axis. (On 3-D axes, scatter(xs, ys) leaves zs at its
# default of 0, which flattens every point onto z = 0.)
fig, ax = plt.subplots(figsize=(12, 8))

# Plot original data
ax.scatter(x_eval_np[:, 0], y_eval_np, label='Original Data', color='dodgerblue', alpha=0.6)

# Plot predicted data
ax.scatter(x_eval_np[:, 0], y_pred_np, label='Prediction Data', color='coral', alpha=0.6)

# Configure plot labels and title
ax.set_title('B-Spline on Heaviside-like Function', fontsize=15)
ax.set_xlabel('$x_1$', fontsize=12)
ax.set_ylabel('$y$ (Ground Truth / Prediction)', fontsize=12)

# Show legend
ax.legend()

plt.show()

# Residuals of the fit (ground truth minus prediction), one per sample.
diff = np.subtract(y_eval_np, y_pred_np)

# The first input coordinate goes on the x axis; a constant-zero dummy
# coordinate fills the y axis so the residual can be shown on z.
x1_np = x_eval_np[:, 0]

fig_diff = plt.figure(figsize=(12, 8))
ax_diff = fig_diff.add_subplot(111, projection='3d')

ax_diff.scatter(
    x1_np,
    np.zeros_like(x1_np),
    diff,
    label='Difference (True - Predicted)',
    color='purple',
    alpha=0.6,
)

# Title first, then the three axis labels in x, y, z order.
ax_diff.set_title('Difference between True and Predicted Values', fontsize=15)
for set_axis_label, axis_text in (
    (ax_diff.set_xlabel, '$x_1$'),
    (ax_diff.set_ylabel, '$x_2$ (dummy)'),
    (ax_diff.set_zlabel, 'Difference'),
):
    set_axis_label(axis_text, fontsize=12)

ax_diff.legend()
plt.show()

# Evaluation metrics over the batch, computed from a single 1-D residual tensor
# (prediction minus target for the first input/output pair).
residual = y_pred[:, 0, 0] - y_eval[:, 0]
mse = residual.pow(2).mean()
rmse = mse.sqrt()
mae = residual.abs().mean()

print(f"Mean Squared Error (MSE): {mse:.4e}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4e}")
print(f"Mean Absolute Error (MAE): {mae:.4e}")