# -*- coding: utf-8 -*-
"""
Created on Sun Sep  8 19:54:22 2024

@author: kernel
"""
import torch
import matplotlib.pyplot as plt
from spline_test import B_batch, coef2curve, curve2coef, extend_grid
from matplotlib import font_manager  
# Global matplotlib configuration: render sans-serif text with the SimHei
# font (allows CJK characters in labels) and disable the Unicode minus sign
# so negative tick labels fall back to an ASCII hyphen, which the chosen font
# can presumably display.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})



# --- Problem setup ---------------------------------------------------------
batch = 100    # number of sample points
in_dim = 1     # number of input dimensions (one spline per input/output pair)
out_dim = 1    # number of output dimensions
num_grid = 10  # knots in the base (un-extended) grid on [-1, 1]
k = 5          # spline order passed to curve2coef / coef2curve

# Seed the RNG so repeated runs draw the same samples and produce the same fit.
torch.manual_seed(0)

# Inputs drawn uniformly from [-1, 1): shape (batch, in_dim) = (100, 1).
x_eval = torch.rand(batch, in_dim) * 2 - 1

# Target y = |x1|: shape (batch, 1) = (100, 1).
y_eval = torch.abs(x_eval[:, 0:1])

# Uniform knots on [-1, 1], one row per input dimension:
# shape (in_dim, num_grid) = (1, 10).
grid = torch.linspace(-1, 1, steps=num_grid)[None, :].expand(in_dim, num_grid)

# Pad the grid with k extra knots on each side. The padding must match the
# spline order: with fewer padding knots than k (this script previously used
# k_extend=3 with k=5) the order-k basis does not fully cover [-1, 1], which
# degrades the fit near the boundaries.
# Presumably shape (in_dim, num_grid + 2 * k) as in pykan's extend_grid —
# confirm against spline_test.
extended_grid = extend_grid(grid, k_extend=k)

# unsqueeze(1) turns the (batch, 1) targets into (batch, 1, 1). This matches a
# (batch, in_dim, out_dim) layout only because in_dim == out_dim == 1 here
# (assumed pykan convention for curve2coef — confirm before changing dims).
coef = curve2coef(x_eval, y_eval.unsqueeze(1), extended_grid, k=k)

# Print the shape of the fitted B-spline coefficients.
print("Fitted B-spline coefficient shape:", coef.shape)

# Evaluate the fitted spline at the same sample points.
y_pred = coef2curve(x_eval, extended_grid, coef, k=k)

# Plot the sampled data against the spline predictions at the same x values.
# Convert tensors explicitly to 1-D NumPy arrays (on CPU, without autograd
# history) rather than passing a (batch, 1) tensor to matplotlib.
x_plot = x_eval[:, 0].detach().cpu().numpy()
y_true_plot = y_eval[:, 0].detach().cpu().numpy()
# y_pred is indexed as [:, 0, 0]: the single input/output pair of this fit.
y_pred_plot = y_pred[:, 0, 0].detach().cpu().numpy()

plt.figure(figsize=(8, 6))
plt.scatter(x_plot, y_true_plot, label="Real data", color='blue')
plt.scatter(x_plot, y_pred_plot, label="Fitting curve", color='red', marker='x')
plt.xlabel('x1')
plt.ylabel('y')
plt.legend()
plt.title('Comparison of B-spline fitting curve and real data')
plt.show()