# Import the required libraries
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
# Set the default resolution of every matplotlib figure to 600 DPI.
plt.rcParams['figure.dpi'] = 600


class ComplexLinear(tf.keras.Model):
    """Linear map followed by a trainable rational activation.

    With ``u = real(x) @ real_weights + real_bias`` and ``d = imag_bias``,
    every output unit evaluates

        (lambda_k1 * u + lambda_k2 * d) / (u**2 + d**2)

    and the units are then averaged over the last axis, so ``call`` returns
    a single column no matter how large ``output_dim`` is.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Number of internal units (columns of ``real_weights``).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Weight matrix and bias of the linear map applied to the real part.
        self.real_weights = tf.Variable(
            tf.random.normal([input_dim, output_dim], mean=0.0, stddev=0.1, dtype=tf.float32), trainable=True)
        self.real_bias = tf.Variable(tf.random.normal([output_dim], mean=0.01, stddev=0.01, dtype=tf.float32),
                                     trainable=True)
        # Per-unit offset `d`; its square keeps the denominator away from zero.
        self.imag_bias = tf.Variable(tf.random.normal([output_dim], mean=-0.1, stddev=0.01, dtype=tf.float32),
                                     trainable=True)
        # Per-unit coefficients of the two rational terms.
        self.lambda_k1 = tf.Variable(tf.random.normal([output_dim], mean=0.5, stddev=0.1, dtype=tf.float32),
                                     trainable=True)
        self.lambda_k2 = tf.Variable(tf.random.normal([output_dim], mean=0.5, stddev=0.1, dtype=tf.float32),
                                     trainable=True)

    def call(self, x):
        """Apply the layer to ``x`` and return the unit average as a column.

        Only the real part of ``x`` is used; for a complex input the
        imaginary part is ignored.

        Args:
            x: Input tensor whose last axis has size ``input_dim``.

        Returns:
            Tensor of shape ``(-1, 1)`` holding the mean over the units.
        """
        if tf.dtypes.as_dtype(x.dtype).is_complex:
            x_real = tf.math.real(x)
        else:
            x_real = x

        real_1 = tf.linalg.matmul(x_real, self.real_weights) + self.real_bias
        real = tf.square(real_1)
        # Shared denominator u**2 + d**2.
        # NOTE(review): it is zero only when both u and imag_bias are exactly
        # zero; nothing guards against that case.
        complex_output = real + self.imag_bias ** 2
        complex_output_1 = real_1 / complex_output
        output_1 = complex_output_1 * self.lambda_k1

        complex_output_2 = self.imag_bias / complex_output
        output_2 = complex_output_2 * self.lambda_k2

        x = output_1 + output_2
        # Collapse the units into one value per sample, kept as a column.
        x = tf.reduce_mean(x, -1)
        x = tf.reshape(x, [-1, 1])
        return x


def target_function(x):
    """Return the unit step of ``x``: 0 where ``x < 0`` and 1 elsewhere.

    Args:
        x: NumPy array (or scalar) of sample positions.

    Returns:
        Integer array with the same shape as ``x``.
    """
    negative_mask = x < 0
    value_below, value_otherwise = 0, 1
    return np.where(negative_mask, value_below, value_otherwise)



# Draw the sample positions uniformly from [-1, 1) and label them with the
# step target.
num_samples = 2000
x_data = np.random.uniform(low=-1, high=1, size=num_samples)
z_data = target_function(x_data)

# Arrange the samples as a single feature column.
inputs = np.reshape(x_data, (-1, 1))

# Hold out half of the samples for testing; the fixed seed makes the split
# reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    inputs,
    z_data,
    random_state=42,
    test_size=0.5,
)

# Network: one ComplexLinear block followed by a single-unit dense read-out.
input_dim = 1
hidden_units = 5
model = tf.keras.Sequential()
model.add(ComplexLinear(input_dim, hidden_units))
model.add(tf.keras.layers.Dense(1))

model.compile(loss='mean_squared_error', optimizer='adam')



class CustomCallback(tf.keras.callbacks.Callback):
    """Record the training and held-out loss at the end of every epoch.

    Attributes set by the constructor:
        losses_train: Training loss reported by Keras, one entry per epoch.
        losses_test: Result of evaluating the model on the held-out data,
            one entry per epoch.
        X_test, y_test: Held-out inputs and targets used for evaluation.
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test, self.y_test = X_test, y_test
        self.losses_train, self.losses_test = [], []

    def on_epoch_end(self, epoch, logs=None):
        """Store both losses and print a progress line every 100 epochs."""
        self.losses_train.append(logs['loss'])

        # Evaluate silently on the held-out data after each epoch.
        test_loss = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        self.losses_test.append(test_loss)

        # `epoch` is zero-based, so report on the 100th, 200th, ... epoch.
        if (epoch + 1) % 100 != 0:
            return
        print(f"Epoch {epoch + 1}, Loss (Train): {logs['loss']:.4e}, Loss (Test): {test_loss:.4e}")



# Train the model and time the fit; the callback records per-epoch losses.
# perf_counter is monotonic, so the measured interval cannot be distorted by
# system clock adjustments.
callback = CustomCallback(X_test, y_test)
start_time = time.perf_counter()
model.fit(X_train, y_train, epochs=1000, batch_size=32, verbose=0, callbacks=[callback])
elapsed_time = time.perf_counter() - start_time

# Predictions on both splits; the error metrics use the test split only.
predictions = model.predict(X_test)
predictions_train = model.predict(X_train)

mse = mean_squared_error(y_test, predictions)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, predictions)

print(f"Mean Squared Error (MSE): {mse:.4e}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4e}")
print(f"Mean Absolute Error (MAE): {mae:.4e}")
print(f"Training time: {elapsed_time:.4f}")

# Sort each split by x so that the line plots run from left to right.
sorted_indices = np.argsort(X_test[:, 0])
X_test_sorted = X_test[sorted_indices]
y_test_sorted = y_test[sorted_indices]
predictions_sorted = predictions.flatten()[sorted_indices]

sorted_indices_train = np.argsort(X_train[:, 0])
X_train_sorted = X_train[sorted_indices_train]
y_train_sorted = y_train[sorted_indices_train]
predictions_sorted_train = predictions_train.flatten()[sorted_indices_train]

# Target versus prediction on the test split.
plt.figure(figsize=(10, 6))
plt.plot(X_test_sorted[:, 0], y_test_sorted, label='Original Data', color='dodgerblue', linewidth=2)
plt.plot(X_test_sorted[:, 0], predictions_sorted, label='Prediction Data', color='coral', linestyle='--', linewidth=2)

# Alternative view: the same comparison on the training split.
# plt.plot(X_train_sorted[:, 0], y_train_sorted, label='Original Data', color='dodgerblue', linewidth=2)
# plt.plot(X_train_sorted[:, 0], predictions_sorted_train, label='Prediction Data', color='coral', linestyle='--', linewidth=2)

plt.legend()
plt.title('XNet Prediction on Heaviside function')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.grid(True, linestyle='--', alpha=0.7)
# NOTE(review): the leading '...' looks like a placeholder; savefig raises if
# the target directory does not exist -- confirm the intended output path.
plt.savefig(f'.../pykan-master/spline_test/figure/XNet1.png')
plt.show()

# Plot the smoothed loss curves of the training and test sets.
# The moving-average kernel is normalised by its own length so the smoothed
# values stay on the scale of the raw losses, and the window is capped at the
# number of recorded epochs so 'valid' mode still yields a curve.
smooth_window = min(100, len(callback.losses_train))
smooth_kernel = np.ones(smooth_window) / smooth_window
smooth_loss = np.convolve(callback.losses_train, smooth_kernel, mode='valid')
smooth_loss_test = np.convolve(callback.losses_test, smooth_kernel, mode='valid')
# Each smoothed value is placed at the last epoch covered by its window.
smooth_epochs = np.arange(smooth_window, len(callback.losses_train) + 1)

plt.figure(figsize=(10, 6))
plt.plot(smooth_epochs, smooth_loss, label='Train Loss', alpha=0.7)
plt.plot(smooth_epochs, smooth_loss_test, label='Test Loss', alpha=0.7)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.yscale('log')
plt.yticks([1, 0.1, 0.01, 0.001, 0.0001], ['$10^0$', '$10^{-1}$', '$10^{-2}$', '$10^{-3}$', '$10^{-4}$'])
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.title('Training and Test Loss over Epochs')
plt.show()