# 多个KAN

from kan import *
from scipy.special import j0
import torch.nn.functional as F
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

def show(y, y_pre):
    # 创建绘图
    plt.figure(figsize=(8, 8))

    # 绘制散点图
    plt.scatter(y, y_pre, alpha=0.7, edgecolors='k', c='b', label='Predicted vs True')

    # 绘制参考线 y = x
    plt.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', label='Ideal Prediction')

    # 添加标签和标题
    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')
    plt.title('True Values vs Predicted Values')
    plt.legend()

    # 显示图形
    plt.grid(True)
    plt.savefig('true_vs_predicted.png')
    plt.show()

# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (Bn=3), 5 grid intervals (grid=5).
# layer_type:
# - s_spline
# - BN
# - Fast
# - Chev
# - Wave
# - Jacobi
# - B01
# - ReLU
# - Fourier
# - Taylor
# - RBF


expressions = [
    ("x[:, 0] * x[:, 1]", [2, 5, 1]),
    # 1.f(x, y) = xy
    ("x[:, 0] / torch.where(x[:, 1] == 0, torch.ones_like(x[:, 1]), x[:, 1])", [2, 5, 1]),
    # 2.f(x, y) = x / y
    ("torch.exp(j0(20 * x[:, 0]) + x[:, 1]**2)", [2, 5, 1]),
    # 3.f(x, y) = exp(J0(20x) + y^2)
    ("torch.tanh(5 * (x[:, 0]**4 + x[:, 1]**4 + x[:, 2]**4 - 1))", [3, 5, 1]),
    # 4.f(x1, x2, x3) = tanh(...)
    ("torch.sqrt((x[:, 0] - x[:, 1])**2 + (x[:, 2] - x[:, 3])**2)", [4, 5, 1])
    # 5.f(x1, x2, x3, x4) = sqrt(...)
]

results = {}
id = 0
best_results = []  # 存储每个 expression 最佳结果的列表

for expression, width in expressions:
    id += 1
    f = lambda x: eval(expression)
    n_var = width[0]
    dataset = create_dataset(f, n_var=n_var, device=device)

    # 为每个 expression 初始化一个空字典来存储 k 的结果
    results[id] = {}

    # 初始化当前 expression 最佳结果的存储变量
    best_k = None
    best_result = None

    for k in range(1, 10):
        model = KAN(width=width, grid=3, k=k, seed=42, device=device, layer_type="BN")
        model.fit(dataset, opt="Adam", steps=1000, lr=0.001)

        # 计算真实值和预测值
        x, y_true = dataset['test_input'], dataset['test_label']  # 从数据集中获取输入数据和真实值
        y_pred = model(x)  # 获取模型的预测输出

        # 1. 计算 MSE
        mse = F.mse_loss(y_pred, y_true)

        # 2. 计算 MAE
        mae = F.l1_loss(y_pred, y_true)

        # 3. 计算 FDE
        fde = torch.sqrt((y_pred[-1] - y_true[-1]) ** 2)

        # 4. 计算 ADE
        ade = torch.mean(torch.sqrt((y_pred - y_true) ** 2))

        # 将每次计算结果存储在字典中
        results[id][k] = {
            'MSE': mse.item(),
            'MAE': mae.item(),
            'FDE': fde.item(),
            'ADE': ade.item()
        }

        # 更新当前 expression 最佳结果
        if best_result is None or mse.item() < best_result['MSE']:
            best_k = k
            best_result = results[id][k]

        # 打印每次迭代的误差结果
        print(
            f"Expression: {id}, k={k} | MSE: {mse.item():.6f}, MAE: {mae.item():.6f}, FDE: {fde.item():.6f}, ADE: {ade.item():.6f}")

    # lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
        # model.auto_symbolic(lib=lib)
        # model.fit(dataset, opt="Adma", steps=500, lr=0.001)
        # model.plot()
        # plt.savefig('plot_image.png')
        # # Obtain the symbolic formula
        # formula = model.symbolic_formula()[0][0]
        # print(formula)

    # 存储当前 expression 的最佳结果
    if best_result is not None:
        best_results.append({
            'Expression': id,
            'Best_k': best_k,
            'MSE': best_result['MSE'],
            'MAE': best_result['MAE'],
            'FDE': best_result['FDE'],
            'ADE': best_result['ADE']
        })
    else:
        best_results.append({
            'Expression': id,
            'Best_k': None,
            'MSE': None,
            'MAE': None,
            'FDE': None,
            'ADE': None
        })


# 打印所有 expression 的最佳结果
print("\nBest Results for All Expressions:")
for result in best_results:
    expr_id = result['Expression']
    print(f"Expression {expr_id}:")
    print(f"  Best k: {result['Best_k']}")
    print(f"  MSE: {result['MSE']:.6f}" if result['MSE'] is not None else "  MSE: N/A")
    print(f"  MAE: {result['MAE']:.6f}" if result['MAE'] is not None else "  MAE: N/A")
    print(f"  FDE: {result['FDE']:.6f}" if result['FDE'] is not None else "  FDE: N/A")
    print(f"  ADE: {result['ADE']:.6f}" if result['ADE'] is not None else "  ADE: N/A")
# model = KAN(width=[2, 5, 1], grid=3, k=3, seed=42, device=device, layer_type='Taylor')
#
# # 创建输入函数
# from kan.utils import create_dataset
# # create dataset f(x,y) = exp(sin(pi*x)+y^2)
# f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
# dataset = create_dataset(f, n_var=2, device=device)
# dataset['train_input'].shape, dataset['train_label'].shape
#
# # plot KAN at initialization
# model(dataset['train_input'])
# model.plot()
# # plt.savefig('plot.jpg')  # 图片存储路径
#
# # 训练模型
# # train the model
# model.fit(dataset, opt="LBFGS", steps=50, lamb=0.001)
# model.plot() # 训练的kan
#
# model.prune()
# model.plot # 再训练一次
#
#
# model.fit(dataset, opt="Adma", steps=50, lr=0.001)
# model = model.refine(10)
# model.fit(dataset, opt="Adma", steps=50, lr=0.001)
# mode = "auto" # "manual"
# '''
# manual:可以直接指定特定参数的符号表达式;
# auto:通过提供一个库lib，模型会自动尝试从库中选取合适的符号表达式来拟合数据
# '''
#
# if mode == "manual":
#     # manual mode 模式选择
#     model.fix_symbolic(0,0,0,'sin')
#     model.fix_symbolic(0,1,0,'x^2')
#     model.fix_symbolic(1,0,0,'exp')
# elif mode == "auto":
#     # 创建一个包含基本函数的库。这个库用于自动符号回归，模型会尝试用这些函数来逼近目标函数。库中的函数包括基本多项式、指数、对数、平方根、双曲正切、正弦和绝对值等
#     # automatic mode
#     lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
#     model.auto_symbolic(lib=lib)
# #
# model.fit(dataset, opt="Adma", steps=50, lr=0.01)
# model.plot()
# plt.savefig('plot_image.png')
# # Obtain the symbolic formula
# formula = model.symbolic_formula()[0][0]
# print(formula)
# from kan.utils import ex_round
# #ex_round函数，该函数用于四舍五入数值
# ex_round(model.symbolic_formula()[0][0],4)