import numpy as np
import torch
import math
from sklearn.model_selection import train_test_split
from torch.optim import Adam
import matplotlib.pyplot as plt
import time  # 用于记录训练时间

# Device configuration: run on CUDA when PyTorch reports it available, else on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Target function the network is trained to approximate.
def target_function(x, y):
    """Evaluate f(x, y) = exp(sin(pi * x + y**2)).

    Works element-wise, so ``x`` and ``y`` may be scalars or broadcast-compatible
    NumPy arrays.
    """
    argument = np.pi * x + y ** 2
    return np.exp(np.sin(argument))

# Generate random input samples (x, y), each uniform in [-1, 1).
# A seeded generator makes the dataset reproducible across runs; without it the
# fixed random_state of the split below would not make the experiment repeatable.
num_samples = 1000
rng = np.random.default_rng(42)
x_data = rng.uniform(-1, 1, num_samples)
y_data = rng.uniform(-1, 1, num_samples)

# Compute the target outputs
z_data = target_function(x_data, y_data)

# Stack (x, y) column-wise into an input matrix of shape (num_samples, 2)
inputs = np.vstack([x_data, y_data]).T
outputs = z_data

# Split into training and test sets (80 / 20)
X_train, X_test, y_train, y_test = train_test_split(inputs, outputs, test_size=0.2, random_state=42)

# Convert to PyTorch float32 tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)  # add a trailing dim -> [N, 1]
y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

# KANLinear: one Kolmogorov-Arnold layer (linear base path + learnable per-edge B-splines)
class KANLinear(torch.nn.Module):
    """Kolmogorov-Arnold layer mapping ``in_features`` to ``out_features``.

    Each output is the sum of
      * a plain linear map of the inputs (the "base" path; no activation is
        applied to it here), and
      * for every (output, input) pair, a learnable B-spline of that input,
        built on a uniform knot grid over [-1, 1] and scaled by its own entry
        of ``spline_scaler``.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        grid_size: number of grid intervals covering [-1, 1].
        spline_order: B-spline order (3 = cubic).
        scale_noise: width of the small symmetric uniform noise used to
            initialise the spline coefficients (bound = scale_noise / (2 * grid_size)).
        scale_base: multiplier on the ``a`` argument of the Kaiming init of ``base_weight``.
        scale_spline: multiplier on the ``a`` argument of the Kaiming init of ``spline_scaler``.
    """

    def __init__(self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knots over grid_range = [-1, 1], extended by spline_order extra
        # knots on each side. Shape: (in_features, grid_size + 2 * spline_order + 1).
        h = 2 / grid_size
        grid = (torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1).expand(in_features, -1).contiguous()
        self.register_buffer("grid", grid)

        # torch.empty does not initialise memory; reset_parameters() fills every tensor.
        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(torch.empty(out_features, in_features, grid_size + spline_order))
        self.spline_scaler = torch.nn.Parameter(torch.empty(out_features, in_features))

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise base_weight, spline_weight and spline_scaler."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        # spline_weight must be filled explicitly: its storage is uninitialised, and
        # leaving it that way lets garbage (or NaN) reach the first forward pass.
        # Small symmetric noise makes every spline start close to zero.
        bound = self.scale_noise / (2 * self.grid_size)
        torch.nn.init.uniform_(self.spline_weight, -bound, bound)
        torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x):
        """Evaluate all B-spline bases at ``x`` via the Cox-de Boor recursion.

        Args:
            x: tensor of shape (batch, in_features).

        Returns:
            Tensor of shape (batch, in_features, grid_size + spline_order). Inputs
            outside [grid[0], grid[-1]) get all-zero bases.
        """
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= self.grid[:, :-1]) & (x < self.grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = ((x - self.grid[:, : -(k + 1)]) / (self.grid[:, k:-1] - self.grid[:, : -(k + 1)]) * bases[:, :, :-1]
                     + (self.grid[:, k + 1 :] - x) / (self.grid[:, k + 1 :] - self.grid[:, 1:(-k)]) * bases[:, :, 1:])
        return bases.contiguous()

    def forward(self, x):
        """Apply the layer to ``x`` of shape (..., in_features).

        Returns:
            Tensor of shape (..., out_features). Any number of leading dims
            (including an empty batch) is supported.
        """
        leading_shape = x.shape[:-1]
        x = x.reshape(-1, self.in_features)
        n_coeffs = self.grid_size + self.spline_order

        base_output = torch.nn.functional.linear(x, self.base_weight)
        # Each (out, in) spline is spline_weight scaled by its spline_scaler entry;
        # without this product spline_scaler and scale_spline would have no effect.
        scaled_spline_weight = self.spline_weight * self.spline_scaler.unsqueeze(-1)
        # Explicit sizes (not -1) so a zero-row batch can still be reshaped.
        spline_output = torch.nn.functional.linear(
            self.b_splines(x).view(x.size(0), self.in_features * n_coeffs),
            scaled_spline_weight.view(self.out_features, self.in_features * n_coeffs),
        )
        return (base_output + spline_output).reshape(*leading_shape, self.out_features)

# KAN: a stack of KANLinear layers ending in a single scalar output
class KAN(torch.nn.Module):
    """Kolmogorov-Arnold network: input -> hidden layer(s) -> 1 output.

    Args:
        input_size: number of input features.
        grid_size: grid intervals used by every layer's splines.
        spline_order: B-spline order used by every layer.
        hidden_size: width of each hidden layer.
        n_hidden_layers: number of hidden layers. The first hidden layer is
            always built, so values below 1 still yield one hidden layer.
    """

    def __init__(self, input_size, grid_size=5, spline_order=3, hidden_size=64, n_hidden_layers=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Every layer, including the output layer, uses the caller's spline settings.
        # Otherwise the output layer would silently fall back to KANLinear's defaults.
        spline_kwargs = dict(grid_size=grid_size, spline_order=spline_order)

        layers = [KANLinear(input_size, hidden_size, **spline_kwargs)]
        for _ in range(n_hidden_layers - 1):
            layers.append(KANLinear(hidden_size, hidden_size, **spline_kwargs))
        layers.append(KANLinear(hidden_size, 1, **spline_kwargs))  # single scalar: the predicted function value

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        """Pass ``x`` (..., input_size) through every layer in order; returns (..., 1)."""
        for layer in self.layers:
            x = layer(x)
        return x

# Instantiate the KAN model
input_size = 2  # input dimensionality: (x, y)
torch.manual_seed(42)  # make parameter initialisation reproducible across runs
model = KAN(input_size=input_size).to(device)

# Optimizer
optimizer = Adam(model.parameters(), lr=1e-3)

# Loss function
loss_fn = torch.nn.MSELoss()

# Move the data to the GPU (if available)
X_train, X_test = X_train.to(device), X_test.to(device)
y_train, y_test = y_train.to(device), y_test.to(device)

# Record the training start time
start_time = time.time()

# Train the model: full-batch gradient descent on the whole training set each epoch
epochs = 5000
train_losses = []
test_losses = []

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()

    # Forward pass
    y_pred = model(X_train)

    # Compute the loss (measured before optimizer.step(), i.e. for this epoch's starting parameters)
    loss = loss_fn(y_pred, y_train)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Record the training loss
    train_losses.append(loss.item())

    # Evaluate on the test set every epoch so both curves have one point per epoch
    model.eval()
    with torch.no_grad():
        test_pred = model(X_test)
        test_loss = loss_fn(test_pred, y_test)
        test_losses.append(test_loss.item())

    # Log every 100 epochs and also the final epoch, which the modulo check alone skips
    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f'Epoch [{epoch}/{epochs}], Train Loss: {loss.item():.6f}, Test Loss: {test_loss.item():.6f}')

# Record the elapsed time (this also includes the per-epoch test evaluation)
elapsed_time = time.time() - start_time

# Plot the training and test loss curves on a log scale
plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Train Loss', alpha=0.7)
plt.plot(test_losses, label='Test Loss', alpha=0.7)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.yscale('log')
# Fixed decade ticks from 1e0 down to 1e-4, rendered as powers of ten
plt.yticks([1, 0.1, 0.01, 0.001, 0.0001], ['$10^0$', '$10^{-1}$', '$10^{-2}$', '$10^{-3}$', '$10^{-4}$'])
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.title('Training and Test Loss over Epochs')
plt.show()



# Predict on the test set in evaluation mode, without tracking gradients,
# and bring predictions and targets back to host NumPy arrays.
model.eval()
with torch.no_grad():
    predictions = model(X_test).cpu().numpy()
y_test_np = y_test.cpu().numpy()

# Error metrics over the test residuals
mse = np.mean(np.square(predictions - y_test_np))
rmse = np.sqrt(mse)
mae = np.mean(np.abs(predictions - y_test_np))

# Report the metrics and the wall-clock training time, one per line
print("\n".join([
    f"Mean Squared Error (MSE): {mse:.4e}",
    f"Root Mean Squared Error (RMSE): {rmse:.4e}",
    f"Mean Absolute Error (MAE): {mae:.4e}",
    f"Training time: {elapsed_time:.4f} seconds",
]))

# Compare true and predicted function values at the test points.
# Both plots share one colour range (vmin/vmax). Otherwise each colour map
# autoscales independently, and equal colours would not mean equal values.
test_x = X_test[:, 0].cpu().numpy()
test_y = X_test[:, 1].cpu().numpy()
true_flat = y_test_np.flatten()
pred_flat = predictions.flatten()
vmin = min(true_flat.min(), pred_flat.min())
vmax = max(true_flat.max(), pred_flat.max())

plt.scatter(test_x, test_y, c=true_flat, cmap='viridis', vmin=vmin, vmax=vmax)
plt.title('True Function Values')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()

plt.scatter(test_x, test_y, c=pred_flat, cmap='viridis', vmin=vmin, vmax=vmax)
plt.title('KAN Predictions')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.show()