from kan import *



torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (Bn=3), 5 grid intervals (grid=5).
model = KAN(width=[2,1], grid=3, k=3, seed=42, device=device)

# 创建输入函数
from kan.utils import create_dataset
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape

# plot KAN at initialization
model(dataset['train_input'])

model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
# model.plot() # 训练的kan

model.prune()
# model.plot # 再训练一次


model.fit(dataset, opt="Adma", steps=50, lr=0.001)
model = model.refine(10)
model.fit(dataset, opt="Adma", steps=50, lr=0.001)
mode = "auto" # "manual"
'''
manual:可以直接指定特定参数的符号表达式;
auto:通过提供一个库lib，模型会自动尝试从库中选取合适的符号表达式来拟合数据
'''

if mode == "manual":
    # manual mode 模式选择
    model.fix_symbolic(0,0,0,'sin')
    model.fix_symbolic(0,1,0,'x^2')
    model.fix_symbolic(1,0,0,'exp')
elif mode == "auto":
    # 创建一个包含基本函数的库。这个库用于自动符号回归，模型会尝试用这些函数来逼近目标函数。库中的函数包括基本多项式、指数、对数、平方根、双曲正切、正弦和绝对值等
    # automatic mode
    lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
    model.auto_symbolic(lib=lib)
#
model.fit(dataset, opt="Adma", steps=50, lr=0.01)
model.plot()
# Obtain the symbolic formula
formula = model.symbolic_formula()[0][0]
print(formula)
from kan.utils import ex_round
#ex_round函数，该函数用于四舍五入数值
ex_round(model.symbolic_formula()[0][0],4)