import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
from scipy.special import sph_harm
# Render every figure at 600 dots per inch (global matplotlib setting).
plt.rcParams.update({'figure.dpi': 600})

class ComplexLinear(tf.keras.Model):
    """Single-layer model built on the reciprocal of a learned complex affine map.

    For every hidden unit ``k`` the layer forms
    ``z_k = x_real @ real_weights[:, k] + real_bias[k] + i * imag_bias[k]`` and
    combines ``Re(z_k) / |z_k|^2`` and ``imag_bias[k] / |z_k|^2`` with the
    learned per-unit scales ``lambda_k1`` and ``lambda_k2``. The per-unit values
    are averaged, so the output has shape ``(batch, 1)``.

    Args:
        input_dim: Number of input features (columns of ``x``).
        output_dim: Number of hidden units averaged into the single output.
    """

    def __init__(self, input_dim, output_dim):
        super(ComplexLinear, self).__init__()
        # NOTE: creation order is kept fixed so the random initial values are
        # drawn from TensorFlow's RNG in the same sequence as before.
        self.real_weights = tf.Variable(
            tf.random.normal([input_dim, output_dim], mean=0, stddev=0.01, dtype=tf.float32), trainable=True)
        self.real_bias = tf.Variable(tf.random.normal([output_dim], mean=0.01, stddev=1, dtype=tf.float32),
                                     trainable=True)
        self.imag_bias = tf.Variable(tf.random.normal([output_dim], mean=-0.001, stddev=0.1, dtype=tf.float32),
                                     trainable=True)
        self.lambda_k1 = tf.Variable(tf.random.normal([output_dim], mean=0.1, stddev=0.1, dtype=tf.float32),
                                     trainable=True)
        self.lambda_k2 = tf.Variable(tf.random.normal([output_dim], mean=0.1, stddev=0.1, dtype=tf.float32),
                                     trainable=True)

    def call(self, x):
        """Return the averaged reciprocal-map response for ``x``, shape ``(batch, 1)``."""
        # Only the real part of the input contributes; for complex input the
        # imaginary part is discarded (it was previously computed but unused).
        x_real = tf.math.real(x) if tf.dtypes.as_dtype(x.dtype).is_complex else x

        # Real part of z = (x_real @ W + b) + i * imag_bias.
        real_1 = tf.linalg.matmul(x_real, self.real_weights) + self.real_bias
        # |z|^2 = Re(z)^2 + Im(z)^2. This can be 0 only if both terms vanish,
        # which would give inf/nan below.
        modulus_sq = tf.square(real_1) + self.imag_bias ** 2

        output_1 = (real_1 / modulus_sq) * self.lambda_k1
        output_2 = (self.imag_bias / modulus_sq) * self.lambda_k2

        # Average across hidden units and return a column vector.
        x = tf.reduce_mean(output_1 + output_2, -1)
        return tf.reshape(x, [-1, 1])

def target_function(x, y, m=1, n=2):
    """Return the real part of the spherical harmonic Y_n^m at the given angles.

    Args:
        x: Azimuthal angle(s) in radians (scalar or array).
        y: Polar angle(s) in radians (scalar or array, broadcast with ``x``).
        m: Order of the harmonic (default 1, as before).
        n: Degree of the harmonic (default 2, as before).

    Returns:
        ``Re(Y_n^m(polar=y, azimuth=x))`` with the same broadcast shape as the inputs.
    """
    try:
        # SciPy >= 1.15: non-deprecated API, signature (n, m, polar, azimuth).
        from scipy.special import sph_harm_y
    except ImportError:
        # Older SciPy: only the legacy API exists, signature (m, n, azimuth, polar).
        from scipy.special import sph_harm as _legacy_sph_harm
        return _legacy_sph_harm(m, n, x, y).real
    return sph_harm_y(n, m, y, x).real


num_samples = 10000
# Seeded generator so the sampled dataset is reproducible from run to run
# (the train/test split below is already seeded with random_state=42).
rng = np.random.default_rng(42)

# 90% of the points are drawn uniformly inside [-1, 1] x [0, 0.5]; the other
# 10% are placed on the four edges of that rectangle.
num_internal_samples = int(0.9 * num_samples)
num_boundary_samples = num_samples - num_internal_samples
x_internal = rng.uniform(-1, 1, num_internal_samples)
y_internal = rng.uniform(0, 0.5, num_internal_samples)

# Split the boundary budget evenly across the left, right, bottom and top edges.
per_edge = num_boundary_samples // 4
x_left   = np.full(per_edge, -1.0)
y_left   = rng.uniform(0, 0.5, per_edge)
x_right  = np.full(per_edge, 1.0)
y_right  = rng.uniform(0, 0.5, per_edge)
y_bottom = np.full(per_edge, 0.0)
x_bottom = rng.uniform(-1, 1, per_edge)
y_top    = np.full(per_edge, 0.5)
x_top    = rng.uniform(-1, 1, per_edge)
x_boundary = np.concatenate([x_left, x_right, x_bottom, x_top])
y_boundary = np.concatenate([y_left, y_right, y_bottom, y_top])

# If the budget is not divisible by 4, put the leftover points on the
# y-edges (y = 0 or y = 0.5).
remaining = num_boundary_samples - len(x_boundary)
if remaining > 0:
    x_extra = rng.uniform(-1, 1, remaining)
    y_extra = rng.choice([0.0, 0.5], remaining)
    x_boundary = np.concatenate([x_boundary, x_extra])
    y_boundary = np.concatenate([y_boundary, y_extra])
x_data = np.concatenate([x_internal, x_boundary])
y_data = np.concatenate([y_internal, y_boundary])


z_data = target_function(x_data, y_data)

# Feature matrix of shape (num_samples, 2): columns are (x, y).
inputs = np.column_stack((x_data, y_data))

X_train, X_test, y_train, y_test = train_test_split(inputs, z_data, test_size=0.5, random_state=42)

# Two input features (x, y). The single ComplexLinear layer averages
# `hidden_units` per-unit terms into one scalar output per sample.
input_dim = 2
hidden_units = 500
model = tf.keras.Sequential(layers=[ComplexLinear(input_dim, hidden_units)])

def lr_schedule(epoch):
    """Return the learning rate for a 0-based ``epoch`` (piecewise-constant decay).

    Epochs [0, 100) use 5e-3, [100, 300) use 1e-3, [300, 600) use 2e-4,
    [600, 900) use 5e-5, and every later epoch uses 1e-5.
    """
    # (exclusive upper epoch bound, learning rate), checked in order.
    # NOTE(review): the 600-800 and 800-900 stages share the same rate; kept
    # as-is, but confirm whether a distinct value was intended for 800-900.
    stages = (
        (100, 0.005),
        (300, 0.001),
        (600, 0.0002),
        (800, 0.00005),
        (900, 0.00005),
    )
    for upper_bound, rate in stages:
        if epoch < upper_bound:
            return rate
    return 0.00001

# Adam is built with its default learning rate. The LearningRateScheduler
# callback (passed to fit below) replaces it with lr_schedule(epoch) each epoch.
optimizer = tf.keras.optimizers.Adam()
model.compile(loss='mean_squared_error', optimizer=optimizer)
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that records the train and held-out loss after every epoch.

    Args:
        X_test: Held-out inputs evaluated at the end of each epoch.
        y_test: Targets matching ``X_test``.

    After training, ``losses_train`` holds the per-epoch training loss reported
    by Keras and ``losses_test`` holds ``model.evaluate`` on the held-out data.
    A progress line is printed every 100 epochs.
    """

    def __init__(self, X_test, y_test):
        super(CustomCallback, self).__init__()
        self.losses_train, self.losses_test = [], []
        self.X_test, self.y_test = X_test, y_test

    def on_epoch_end(self, epoch, logs=None):
        train_loss = logs['loss']
        self.losses_train.append(train_loss)

        # Full pass over the held-out data every epoch (adds to training time).
        test_loss = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        self.losses_test.append(test_loss)

        epochs_done = epoch + 1
        if epochs_done % 100 != 0:
            return
        print(f"Epoch {epochs_done}, Loss (Train): {train_loss:.4e}, Loss (Test): {test_loss:.4e}")

callback = CustomCallback(X_test, y_test)

# perf_counter is monotonic and high-resolution, so it suits interval timing.
# time.time() follows the wall clock and can jump if the clock is adjusted.
# The measured time includes the per-epoch test evaluation done by `callback`.
start_time = time.perf_counter()
model.fit(X_train, y_train, epochs=1200, batch_size=32, verbose=0,
          callbacks=[callback, lr_scheduler])
elapsed_time = time.perf_counter() - start_time

# NOTE(review): `predictions` is not used later in this script; kept for
# interactive inspection.
predictions = model.predict(X_test)

# Dense evaluation grid over the sampling rectangle [-1, 1] x [0, 0.5].
# A single resolution constant keeps linspace and the reshapes in agreement.
grid_resolution = 100
x_lin = np.linspace(-1, 1, grid_resolution)
y_lin = np.linspace(0, 0.5, grid_resolution)
X_grid, Y_grid = np.meshgrid(x_lin, y_lin)  # each shaped (len(y_lin), len(x_lin))

# One (x, y) row per grid node, in the row-major order of X_grid / Y_grid.
grid_points = np.column_stack((X_grid.ravel(), Y_grid.ravel()))

Z_real = target_function(grid_points[:, 0], grid_points[:, 1]).reshape(X_grid.shape)

Z_pred = model.predict(grid_points).reshape(X_grid.shape)

def _plot_surface(X, Y, Z, *, cmap, title, xlabel='X', ylabel='Y', zlabel='Z',
                  zlabelpad=None, colorbar=False):
    """Draw ``Z`` over the ``(X, Y)`` mesh as a 3-D surface in a new figure and show it.

    Args:
        X, Y: Mesh coordinate arrays (as returned by ``np.meshgrid``).
        Z: Surface heights, same shape as ``X`` and ``Y``.
        cmap: Matplotlib colormap name.
        title: Figure title.
        xlabel, ylabel, zlabel: Axis labels.
        zlabelpad: Optional padding for the z-axis label.
        colorbar: Whether to attach a colorbar for the surface.
    """
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')
    surface = ax.plot_surface(X, Y, Z, cmap=cmap, alpha=0.8)
    ax.set_title(title, fontsize=15)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_zlabel(zlabel, fontsize=12)
    if zlabelpad is not None:
        ax.zaxis.labelpad = zlabelpad
    if colorbar:
        plt.colorbar(surface, ax=ax, shrink=0.5, aspect=5)
    plt.show()


_plot_surface(X_grid, Y_grid, Z_real, cmap='viridis', title='True Function')

_plot_surface(X_grid, Y_grid, Z_pred, cmap='plasma', title='Prediction on Grid')

# Pointwise error of the surrogate on the grid (truth minus prediction).
Z_diff = Z_real - Z_pred
_plot_surface(X_grid, Y_grid, Z_diff, cmap='coolwarm', title='Prediction Difference',
              xlabel='x', ylabel='y', zlabel='Difference', zlabelpad=7, colorbar=True)

def smooth_curve(values, window_size=10):
    """Return the trailing moving average of ``values`` over ``window_size`` points.

    Only fully covered windows are kept (``mode='valid'``), so the result has
    ``len(values) - window_size + 1`` elements. Element ``i`` averages
    ``values[i:i + window_size]``.

    Args:
        values: 1-D sequence of numbers.
        window_size: Number of points per window. It must be >= 1. If it
            exceeds ``len(values)`` it is clamped, yielding the overall mean.

    Returns:
        1-D float array of smoothed values (empty if ``values`` is empty).

    Raises:
        ValueError: If ``window_size`` is less than 1.
    """
    values = np.asarray(values, dtype=float)
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")
    if values.size == 0:
        return values
    # np.convolve swaps its operands when the kernel is longer than the data,
    # which would return a wrong-length, wrongly scaled result, so clamp first.
    window_size = min(window_size, values.size)
    return np.convolve(values, np.ones(window_size) / window_size, mode='valid')

# Convert the per-epoch MSE histories to RMSE, then apply a 5-epoch moving
# average. Index i of each smoothed curve averages epochs i .. i+4.
smooth_train_loss, smooth_test_loss = (
    smooth_curve(np.sqrt(history), window_size=5)
    for history in (callback.losses_train, callback.losses_test)
)

plt.figure(figsize=(10, 6))
for curve, curve_label in ((smooth_train_loss, 'Train'), (smooth_test_loss, 'Test')):
    plt.plot(curve, label=curve_label, alpha=0.7)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('RMSE', fontsize=14)
plt.yscale('log')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.title('Training Curve')
plt.show()

# Accuracy of the surrogate on the dense evaluation grid (not the test split).
# Flatten so sklearn treats each grid node as one sample of a single output,
# rather than treating each of the grid's columns as a separate output.
z_true_flat = Z_real.ravel()
z_pred_flat = Z_pred.ravel()
mse = mean_squared_error(z_true_flat, z_pred_flat)
rmse = np.sqrt(mse)
mae = mean_absolute_error(z_true_flat, z_pred_flat)

# Label each figure so the output is interpretable on its own.
print(f"MSE:  {mse:.4e}")
print(f"RMSE: {rmse:.4e}")
print(f"MAE:  {mae:.4e}")
print(f"Training time (s): {elapsed_time:.4f}")