# Import the required libraries
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
plt.rcParams['figure.dpi'] = 600

class ComplexLinear(tf.keras.Model):
    """Layer of trainable rational basis functions averaged into one output.

    For every unit the layer forms the affine response ``u = x @ W + b`` and,
    with the trainable per-unit offset ``d`` (``imag_bias``), evaluates

        lambda_k1 * u / (u**2 + d**2) + lambda_k2 * d / (u**2 + d**2)

    The per-unit values are then averaged over the unit axis, so the model
    produces a single column with one value per input row.

    Args:
        input_dim: Number of input features (rows of the weight matrix).
        output_dim: Number of basis functions (columns of the weight matrix).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Affine map applied to the (real part of the) input.
        self.real_weights = tf.Variable(
            tf.random.normal([input_dim, output_dim], mean=0, stddev=0.1, dtype=tf.float32),
            trainable=True)
        self.real_bias = tf.Variable(
            tf.random.normal([output_dim], mean=0.01, stddev=0.1, dtype=tf.float32),
            trainable=True)
        # Per-unit offset ``d``; its square is added to the denominator.
        self.imag_bias = tf.Variable(
            tf.random.normal([output_dim], mean=-0.1, stddev=0.1, dtype=tf.float32),
            trainable=True)
        # Per-unit scale of the ``u / (u**2 + d**2)`` term.
        self.lambda_k1 = tf.Variable(
            tf.random.normal([output_dim], mean=0.2, stddev=0.1, dtype=tf.float32),
            trainable=True)
        # Per-unit scale of the ``d / (u**2 + d**2)`` term.
        self.lambda_k2 = tf.Variable(
            tf.random.normal([output_dim], mean=0.5, stddev=0.1, dtype=tf.float32),
            trainable=True)

    def call(self, x):
        """Evaluate the layer on a batch.

        Args:
            x: Real or complex tensor whose last dimension matches
                ``input_dim``. Only the real part of a complex input is used.

        Returns:
            Tensor reshaped to ``[-1, 1]`` holding the mean over units of the
            weighted rational responses.
        """
        # The imaginary part of the input never entered the computation, so
        # it is no longer materialised; complex inputs are reduced to their
        # real part only.
        if tf.dtypes.as_dtype(x.dtype).is_complex:
            x = tf.math.real(x)
        # Align the input dtype with the weights so that float64 / complex128
        # inputs do not fail inside matmul with a dtype mismatch.
        x_real = tf.cast(x, self.real_weights.dtype)

        real_1 = tf.linalg.matmul(x_real, self.real_weights) + self.real_bias
        # u**2 + d**2; this is zero only if both u and d are exactly zero.
        complex_output = tf.square(real_1) + self.imag_bias ** 2

        output_1 = (real_1 / complex_output) * self.lambda_k1
        output_2 = (self.imag_bias / complex_output) * self.lambda_k2

        # Average the unit responses and return them as a column vector.
        x = tf.reduce_mean(output_1 + output_2, -1)
        return tf.reshape(x, [-1, 1])


def target_function(x, y):
    """Return the ground-truth value ``x * y`` (element-wise for arrays)."""
    product = x * y
    return product


# Nominal sample budget: 90% of it is drawn from the interior of the square
# [-1, 1] x [-1, 1] and the remaining count sizes each boundary batch.
num_samples = 2000
num_internal_samples = int(num_samples * 0.9)
num_boundary_samples = num_samples - num_internal_samples

# Interior points: both coordinates uniform on [-1, 1).
x_internal = np.random.uniform(-1, 1, size=num_internal_samples)
y_internal = np.random.uniform(-1, 1, size=num_internal_samples)

# Boundary points. The four draws happen in the same order as before so the
# random stream is consumed identically: first points on the vertical edges
# (x = +/-1, y free), then points on the horizontal edges (y = +/-1, x free).
# NOTE(review): this yields 2 * num_boundary_samples boundary points, so the
# full data set is larger than num_samples - confirm that this is intended.
x_on_vertical_edges = np.random.choice([-1, 1], size=num_boundary_samples)
y_on_vertical_edges = np.random.uniform(-1, 1, size=num_boundary_samples)
x_on_horizontal_edges = np.random.uniform(-1, 1, size=num_boundary_samples)
y_on_horizontal_edges = np.random.choice([-1, 1], size=num_boundary_samples)
x_boundary = np.concatenate((x_on_vertical_edges, x_on_horizontal_edges))
y_boundary = np.concatenate((y_on_vertical_edges, y_on_horizontal_edges))

# Interior samples first, boundary samples after them.
x_data = np.concatenate((x_internal, x_boundary))
y_data = np.concatenate((y_internal, y_boundary))

# Regression targets and the two-column (x, y) feature matrix.
z_data = target_function(x_data, y_data)
inputs = np.column_stack((x_data, y_data))

# Half of the points are held out for testing; the split itself is seeded.
X_train, X_test, y_train, y_test = train_test_split(
    inputs,
    z_data,
    test_size=0.5,
    random_state=42,
)


# Two input features (x, y) feed a single ComplexLinear layer whose
# hidden_units basis functions are averaged into one prediction.
input_dim = 2
hidden_units = 5000
model = tf.keras.Sequential()
model.add(ComplexLinear(input_dim, hidden_units))


def lr_schedule(epoch):
    """Return the step-wise decayed learning rate for a zero-based *epoch*.

    The rate is 1e-3 below epoch 300, 1e-4 below 600, 1e-5 below 800 and
    1e-6 from epoch 800 onwards.
    """
    # (exclusive upper epoch bound, learning rate), in ascending order.
    stages = (
        (300, 0.001),
        (600, 0.0001),
        (800, 0.00001),
    )
    for upper_bound, rate in stages:
        if epoch < upper_bound:
            return rate
    return 0.000001


# Callback that asks lr_schedule for the learning rate of each epoch.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule=lr_schedule)

# Adam is created with its default settings; the loss is mean squared error.
model.compile(
    loss='mean_squared_error',
    optimizer=tf.keras.optimizers.Adam(),
)


class CustomCallback(tf.keras.callbacks.Callback):
    """Record the training loss and the held-out loss after every epoch.

    Args:
        X_test: Inputs passed to ``model.evaluate`` at the end of each epoch.
        y_test: Targets passed to ``model.evaluate`` at the end of each epoch.
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        # One entry per finished epoch, in epoch order.
        self.losses_train, self.losses_test = [], []
        self.X_test, self.y_test = X_test, y_test

    def on_epoch_end(self, epoch, logs=None):
        """Store both losses and print a progress line every 100th epoch."""
        # Training loss as reported by Keras for the epoch that just ended.
        self.losses_train.append(logs['loss'])

        # Loss of the current weights on the held-out data.
        test_loss = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        self.losses_test.append(test_loss)

        # ``epoch`` is zero-based, so report on epochs 100, 200, ...
        if (epoch + 1) % 100 != 0:
            return
        print(f"Epoch {epoch + 1}, Loss (Train): {logs['loss']:.4e}, Loss (Test): {test_loss:.4e}")


# Train for 1000 epochs while tracking train/test loss, and time the fit call.
callback = CustomCallback(X_test, y_test)
start_time = time.time()
model.fit(
    X_train,
    y_train,
    epochs=1000,
    batch_size=32,
    verbose=0,
    callbacks=[callback, lr_scheduler],
)
elapsed_time = time.time() - start_time

# Predictions of the trained model on the held-out inputs.
predictions = model.predict(X_test)



# ---------------------------------------------------------

# Regular 100 x 100 evaluation grid covering [-1, 1] x [-1, 1].
x_lin = np.linspace(-1, 1, 100)
y_lin = np.linspace(-1, 1, 100)
X_grid, Y_grid = np.meshgrid(x_lin, y_lin)

# Flatten the grid into an (N, 2) matrix of (x, y) rows for the model.
grid_points = np.column_stack((X_grid.ravel(), Y_grid.ravel()))

# Exact and predicted surfaces, folded back to the grid layout.
Z_real = target_function(grid_points[:, 0], grid_points[:, 1]).reshape(X_grid.shape)
Z_pred = model.predict(grid_points).reshape(X_grid.shape)

# ---------------------------------------------------------

def _new_3d_axes():
    """Create a 12 x 8 figure holding a single 3D axes; return (figure, axes)."""
    figure = plt.figure(figsize=(12, 8))
    return figure, figure.add_subplot(111, projection='3d')


def _label_axes(axes, x_text, y_text, z_text):
    """Set the x, y and z axis labels of *axes* at font size 12."""
    axes.set_xlabel(x_text, fontsize=12)
    axes.set_ylabel(y_text, fontsize=12)
    axes.set_zlabel(z_text, fontsize=12)


# 3D surface of the exact function.
fig, ax = _new_3d_axes()
ax.plot_surface(X_grid, Y_grid, Z_real, cmap='viridis', alpha=0.8)
ax.set_title('True Function: xy', fontsize=15)
_label_axes(ax, 'X', 'Y', 'Z')
plt.show()

# 3D surface of the model prediction.
fig, ax = _new_3d_axes()
ax.plot_surface(X_grid, Y_grid, Z_pred, cmap='plasma', alpha=0.8)
ax.set_title('XNet Prediction on Grid', fontsize=15)
_label_axes(ax, 'X', 'Y', 'Z')
# plt.savefig(r'Xnet_Prediction.png')
plt.show()

# 3D surface of the point-wise difference (exact minus predicted).
Z_diff = Z_real - Z_pred
fig, ax = _new_3d_axes()
diff_surface = ax.plot_surface(X_grid, Y_grid, Z_diff, cmap='coolwarm', alpha=0.8)
ax.set_title('XNet with 5000 basis functions', fontsize=15)
_label_axes(ax, 'x', 'y', 'Difference')
ax.zaxis.labelpad = 7
plt.colorbar(diff_surface, ax=ax, shrink=0.5, aspect=5)
plt.show()



def smooth_curve(values, window_size=10):
    """Return the moving average of *values* taken over full windows only.

    Args:
        values: One-dimensional sequence of numbers.
        window_size: Number of consecutive points averaged per output sample.

    Returns:
        A NumPy array with ``len(values) - window_size + 1`` samples, or an
        empty array when *values* holds fewer than ``window_size`` points.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be at least 1, got {window_size}")

    values = np.asarray(values)
    # With fewer points than the window, np.convolve(mode='valid') silently
    # swaps its operands and returns values that are not a moving average
    # (and it raises on empty input). No full window fits, so return nothing.
    if values.size < window_size:
        return np.empty(0, dtype=float)

    kernel = np.ones(window_size) / window_size
    return np.convolve(values, kernel, mode='valid')

# The recorded losses are MSE values; take the square root to get RMSE and
# smooth each curve with a 5-epoch moving average.
smooth_train_loss = smooth_curve(np.sqrt(callback.losses_train), window_size=5)
smooth_test_loss = smooth_curve(np.sqrt(callback.losses_test), window_size=5)

# Train and test RMSE on a logarithmic y axis.
plt.figure(figsize=(10, 6))
for curve, curve_label in ((smooth_train_loss, 'Train'), (smooth_test_loss, 'Test')):
    plt.plot(curve, label=curve_label, alpha=0.7)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('RMSE', fontsize=14)
plt.yscale('log')
# plt.yticks([1, 0.1, 0.01, 0.001, 0.0001, 0.00001], ['$10^0$', '$10^{-1}$', '$10^{-2}$', '$10^{-3}$', '$10^{-4}$', '$10^{-5}$'])
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.title('XNet with 5000 basis functions')
plt.show()

# plt.figure(figsize=(10, 6))
# # smooth_loss_train = np.convolve(np.sqrt(callback.losses_train), np.ones(1)/1)
# # smooth_loss_test = np.convolve(np.sqrt(callback.losses_test), np.ones(1)/1)
# plt.plot(np.sqrt(callback.losses_train), label='Train', alpha=0.7)
# plt.plot(np.sqrt(callback.losses_test), label='Test', alpha=0.7)
# # plt.plot(smooth_loss_train, label='Train', alpha=0.7)
# # plt.plot(smooth_loss_test, label='Test', alpha=0.7)
# plt.xlabel('Epochs', fontsize=14)
# plt.ylabel('RMSE', fontsize=14)
# plt.yscale('log')
# plt.legend()
# plt.grid(True, linestyle='--', alpha=0.7)
# plt.title('Training and Test Loss over Epochs')
# plt.show()

# Error metrics of the predicted surface against the exact surface on the grid.
mse = mean_squared_error(y_true=Z_real, y_pred=Z_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_true=Z_real, y_pred=Z_pred)

# Report the metrics together with the wall-clock time measured around fit().
print(f"Mean Squared Error (MSE): {mse:.4e}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4e}")
print(f"Mean Absolute Error (MAE): {mae:.4e}")
print(f"Training time: {elapsed_time:.4f}")