# 导入必要的库
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from scipy.optimize import minimize


# Custom ComplexLinear layer (the XNet core).
class ComplexLinear(tf.keras.Model):
    """XNet core: an affine projection followed by a rational activation.

    For each of the ``output_dim`` hidden units ``k`` the layer computes

        z_k   = x @ real_weights[:, k] + real_bias[k]
        den_k = z_k**2 + imag_bias[k]**2
        out_k = lambda_k1[k] * z_k / den_k + lambda_k2[k] * imag_bias[k] / den_k

    and then averages ``out_k`` over the hidden units, so the result has
    shape ``(batch, 1)`` regardless of ``output_dim``.

    NOTE(review): because the hidden units are averaged here, the next layer
    only ever sees a single feature; confirm this reduction is intended
    rather than letting a Dense layer weight the ``output_dim`` unit outputs.
    NOTE(review): ``den_k`` has no epsilon, so it is 0 (division by zero)
    if ``z_k`` and ``imag_bias[k]`` are both exactly 0.

    Args:
        input_dim: number of input features.
        output_dim: number of hidden units.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Projection weights and bias of the affine part.
        self.real_weights = tf.Variable(
            tf.random.normal([input_dim, output_dim], mean=0.0, stddev=0.1, dtype=tf.float32), trainable=True)
        self.real_bias = tf.Variable(tf.random.normal([output_dim], mean=0.01, stddev=0.01, dtype=tf.float32),
                                     trainable=True)
        # The `d` term of the activation; initialised around -0.1 so that
        # den_k = z_k**2 + d**2 starts away from zero.
        self.imag_bias = tf.Variable(tf.random.normal([output_dim], mean=-0.1, stddev=0.01, dtype=tf.float32),
                                     trainable=True)
        # Per-unit mixing coefficients of the two activation terms.
        self.lambda_k1 = tf.Variable(tf.random.normal([output_dim], mean=0.5, stddev=0.1, dtype=tf.float32),
                                     trainable=True)
        self.lambda_k2 = tf.Variable(tf.random.normal([output_dim], mean=0.5, stddev=0.1, dtype=tf.float32),
                                     trainable=True)

    def call(self, x):
        """Apply the layer; ``x`` is assumed to be (batch, input_dim). Returns (batch, 1)."""
        # Only the real part of a complex input is used; its imaginary part
        # does not enter the computation.
        if tf.dtypes.as_dtype(x.dtype).is_complex:
            x = tf.math.real(x)

        z = tf.linalg.matmul(x, self.real_weights) + self.real_bias
        denominator = tf.square(z) + self.imag_bias ** 2
        out = (z / denominator) * self.lambda_k1 + (self.imag_bias / denominator) * self.lambda_k2

        # Average over the hidden units, then restore a trailing unit axis.
        out = tf.reduce_mean(out, -1)
        return tf.reshape(out, [-1, 1])


# Benchmark target surface used to generate the data: exp(sin(pi*x + y^2)).
def target_function(x, y):
    """Return exp(sin(pi * x + y**2)); element-wise for NumPy array inputs."""
    angle = np.pi * x + y ** 2
    return np.exp(np.sin(angle))


# Seed Python, NumPy and TensorFlow once, so that the sampled data, the
# train/test split and the model's initial weights are all reproducible
# (a fixed split seed alone is meaningless if the data itself is random).
RANDOM_SEED = 42
tf.keras.utils.set_random_seed(RANDOM_SEED)

# Draw num_samples random (x, y) points uniformly from [-1, 1] x [-1, 1].
num_samples = 1000
x_data = np.random.uniform(-1, 1, num_samples)
y_data = np.random.uniform(-1, 1, num_samples)

# Target values z = exp(sin(pi*x + y^2)), shape (num_samples,).
z_data = target_function(x_data, y_data)

# Use x and y as the two input columns: shape (num_samples, 2).
inputs = np.column_stack([x_data, y_data])

# 80% training / 20% test split.
X_train, X_test, y_train, y_test = train_test_split(
    inputs, z_data, test_size=0.2, random_state=RANDOM_SEED
)

# Build the XNet model: one ComplexLinear (XNet) layer followed by a Dense(1) output.
input_dim = 2       # two input features: (x, y)
hidden_units = 64   # number of XNet hidden units
model = tf.keras.Sequential()
model.add(ComplexLinear(input_dim, hidden_units))
model.add(tf.keras.layers.Dense(1))

# Build explicitly so every layer has created its variables before the
# trainable parameters are flattened for the optimizer.
model.build(input_shape=(None, input_dim))


# Loss function: mean squared error.
def loss_fn(y_true, y_pred):
    """Mean squared error between targets and predictions.

    The model returns predictions of shape (batch, 1), while the targets
    produced by train_test_split are 1-D (batch,). Subtracting those directly
    would broadcast to a (batch, batch) matrix and average every pairwise
    difference, an objective that is minimised by a constant prediction. The
    targets are therefore reshaped to the prediction's shape and cast to its
    dtype before the element-wise difference is taken.

    Args:
        y_true: targets; array-like with the same number of elements as y_pred.
        y_pred: model predictions (tensor or array-like).

    Returns:
        Scalar tensor holding the mean of the squared element-wise errors.

    Raises:
        tf.errors.InvalidArgumentError: from tf.reshape if the element counts differ.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(tf.reshape(y_true, tf.shape(y_pred)), y_pred.dtype)
    return tf.reduce_mean(tf.square(y_true - y_pred))


# Extract the model parameters as one flat vector.
def get_weights():
    """Flatten every trainable variable of ``model`` into one float64 vector.

    Each variable is raveled and cast to float64 (the dtype handed to the
    SciPy optimizer). The pieces are concatenated in
    ``model.trainable_variables`` order, which is the same order
    ``set_weights`` uses to unpack them.
    """
    return np.concatenate(
        [np.ravel(var.numpy()).astype(np.float64) for var in model.trainable_variables]
    )


# Write a flat parameter vector back into the model.
def set_weights(flat_weights):
    """Unpack a flat parameter vector into ``model``'s trainable variables.

    This is the inverse of ``get_weights``. It consumes ``flat_weights`` in
    ``model.trainable_variables`` order, reshapes each slice to the variable's
    shape, casts it to float32 and assigns it.

    Args:
        flat_weights: 1-D array whose length equals the total number of
            trainable parameters.

    Raises:
        ValueError: if ``flat_weights`` does not have exactly that length.
    """
    flat_weights = np.asarray(flat_weights)
    variables = model.trainable_variables
    shapes = [tuple(var.shape) for var in variables]
    # int(...) matters for scalar variables: np.prod(()) is the float 1.0,
    # which cannot be used as a slice bound.
    sizes = [int(np.prod(shape)) for shape in shapes]
    expected = sum(sizes)
    if flat_weights.size != expected:
        raise ValueError(
            f"set_weights expected {expected} values, got {flat_weights.size}"
        )

    offset = 0
    for var, shape, size in zip(variables, shapes, sizes):
        chunk = flat_weights[offset:offset + size]
        var.assign(chunk.reshape(shape).astype(np.float32))
        offset += size


# Objective for SciPy: training loss together with its gradient.
def loss_and_grads(flat_weights):
    """Return (loss, gradient) at ``flat_weights`` for ``minimize(..., jac=True)``.

    The parameters are loaded into the model, the training loss on
    (X_train, y_train) is recorded on a GradientTape, and it is
    differentiated with respect to every trainable variable.

    Returns:
        Tuple ``(loss, grad)``. ``loss`` is a float64 scalar. ``grad`` is a
        float64 vector flattened in ``model.trainable_variables`` order, which
        matches the layout produced by ``get_weights``.
    """
    set_weights(flat_weights)
    params = model.trainable_variables

    with tf.GradientTape() as tape:
        loss_value = loss_fn(y_train, model(X_train))

    grad_list = tape.gradient(loss_value, params)
    flat_grads = np.concatenate(
        [np.ravel(g.numpy()).astype(np.float64) for g in grad_list]
    )
    return loss_value.numpy().astype(np.float64), flat_grads


# Per-iteration loss history, filled by `callback` and plotted afterwards.
train_losses = []
test_losses = []


# Optimizer callback: record and print the train/test loss after each iteration.
def callback(weights):
    """Load SciPy's current parameter vector and log the train/test MSE.

    Both losses are appended to ``train_losses`` / ``test_losses`` and
    printed on one line.
    """
    set_weights(weights)
    train_loss = loss_fn(y_train, model(X_train)).numpy()
    test_loss = loss_fn(y_test, model(X_test)).numpy()
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    print(f"Train Loss: {train_loss:.4e}, Test Loss: {test_loss:.4e}")


# Optimise all model parameters with L-BFGS-B. loss_and_grads returns
# (loss, gradient) as a pair, hence jac=True; callback records the losses.
initial_weights = get_weights()

# perf_counter is monotonic and high-resolution, which suits elapsed-time
# measurement (time.time can jump if the system clock is adjusted).
start_time = time.perf_counter()
result = minimize(
    fun=loss_and_grads,
    x0=initial_weights,
    method='L-BFGS-B',
    jac=True,
    callback=callback,
    options={'maxiter': 500},
)
elapsed_time = time.perf_counter() - start_time

# Report why the optimizer stopped (convergence, iteration limit, or failure).
print(f"Optimizer finished: success={result.success}, "
      f"iterations={result.nit}, message={result.message}")

# Load the final parameters into the model.
set_weights(result.x)

# Predict on the held-out test set.
predictions = model.predict(X_test)

# sklearn's MSE on (n,) targets vs (n, 1) predictions.
mse = mean_squared_error(y_test, predictions)
print(f"Mean Squared Error (MSE): {mse:.4e}")
print(f"Training time: {elapsed_time:.4f}")

# Side-by-side scatter plots over the test inputs: ground truth (left) and
# XNet prediction (right), each coloured by function value.
plt.figure(figsize=(10, 6))

panels = (
    (y_test, 'Original Test Function'),
    (predictions.flatten(), 'XNet Prediction on Test Set'),
)
for panel_index, (panel_values, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.scatter(X_test[:, 0], X_test[:, 1], c=panel_values, cmap='viridis')
    plt.colorbar()
    plt.title(panel_title)
    plt.xlabel('x')
    plt.ylabel('y')

plt.tight_layout()
plt.show()

# Training and test loss per optimizer iteration, on a logarithmic y axis.
plt.figure(figsize=(10, 6))
for loss_history, curve_label in ((train_losses, 'Training Loss'), (test_losses, 'Test Loss')):
    plt.plot(loss_history, label=curve_label, alpha=0.7)

plt.xlabel('Iterations', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.yscale('log')
# Fixed decade ticks from 1 down to 1e-4 with math-text labels.
decade_ticks = [1, 0.1, 0.01, 0.001, 0.0001]
decade_labels = ['$10^0$', '$10^{-1}$', '$10^{-2}$', '$10^{-3}$', '$10^{-4}$']
plt.yticks(decade_ticks, decade_labels)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.title('Training and Test Loss over Iterations')
plt.show()