# Import required libraries
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split

# Define the core module of XNet: ComplexLinear layer
class ComplexLinear(tf.keras.Model):
    """XNet block: a dense projection followed by a Cauchy-style rational activation.

    For every hidden unit ``k`` the block forms ``z_k = x @ W[:, k] + b_k`` and,
    using the per-unit offset ``d_k`` (``imag_bias``), evaluates

        lambda_k1 * z_k / (z_k**2 + d_k**2) + lambda_k2 * d_k / (z_k**2 + d_k**2)

    The ``output_dim`` unit responses are then averaged into a single value per
    sample and returned as a column of shape ``(-1, 1)``.

    Args:
        input_dim: Number of input features (rows of ``real_weights``).
        output_dim: Number of hidden units.
    """

    def __init__(self, input_dim, output_dim):
        super(ComplexLinear, self).__init__()
        # Projection weights W (input_dim x output_dim) and bias b.
        self.real_weights = tf.Variable(
            tf.random.normal([input_dim, output_dim], mean=0, stddev=0.1, dtype=tf.float32), trainable=True)
        self.real_bias = tf.Variable(tf.random.normal([output_dim], mean=0.01, stddev=0.1, dtype=tf.float32),
                                     trainable=True)
        # Per-unit offset d added (squared) to the denominator of both rational terms.
        self.imag_bias = tf.Variable(tf.random.normal([output_dim], mean=-0.1, stddev=0.1, dtype=tf.float32),
                                     trainable=True)
        # Per-unit output coefficients for the two rational terms.
        self.lambda_k1 = tf.Variable(tf.random.normal([output_dim], mean=0.5, stddev=0.1, dtype=tf.float32),
                                     trainable=True)
        self.lambda_k2 = tf.Variable(tf.random.normal([output_dim], mean=1, stddev=0.1, dtype=tf.float32),
                                     trainable=True)

    def call(self, x):
        """Apply the block to a batch of inputs.

        Args:
            x: Input tensor whose last axis holds ``input_dim`` features
                (assumed 2-D ``(batch, input_dim)`` -- TODO confirm for other
                ranks, since the final reshape flattens everything into one
                column). If ``x`` is complex, only its real part is used.

        Returns:
            Float tensor of shape ``(-1, 1)`` holding the mean unit response.
        """
        # A complex input contributes only its real part; the imaginary part
        # never enters the computation, so it is not materialized at all.
        if tf.dtypes.as_dtype(x.dtype).is_complex:
            x = tf.math.real(x)

        z = tf.linalg.matmul(x, self.real_weights) + self.real_bias
        # Shared denominator z^2 + d^2 of both rational terms.
        # NOTE(review): this is zero (division by zero) if z == 0 and
        # imag_bias == 0 for some unit at the same time.
        denom = tf.square(z) + self.imag_bias ** 2
        output_1 = (z / denom) * self.lambda_k1
        output_2 = (self.imag_bias / denom) * self.lambda_k2

        # Average the unit responses into a single value per sample.
        out = tf.reduce_mean(output_1 + output_2, -1)
        return tf.reshape(out, [-1, 1])

# Regression target: a radial sine term in (x1, x2) plus a bilinear x3 * x4 term.
def target_function(x_1, x_2, x_3, x_4):
    """Evaluate ``sin(pi * (x_1**2 + x_2**2)) + x_3 * x_4`` elementwise.

    Accepts scalars or NumPy arrays of matching/broadcastable shapes and
    returns a result of the broadcast shape.
    """
    radius_sq = x_1**2 + x_2**2
    oscillation = np.sin(np.pi * radius_sq)
    interaction = x_3 * x_4
    return oscillation + interaction

# Generate 4000 random samples uniformly in [-1, 1] for each input dimension.
# The draws come from a generator seeded with the same value as the
# train/test split below, so the whole dataset (not only the split) is
# reproducible from run to run.
num_samples = 4000
data_rng = np.random.default_rng(42)
x_1_data, x_2_data, x_3_data, x_4_data = data_rng.uniform(-1.0, 1.0, size=(4, num_samples))

# Compute target function values
z_data = target_function(x_1_data, x_2_data, x_3_data, x_4_data)

# Stack the four features as the columns of a (num_samples, 4) input matrix
inputs = np.column_stack([x_1_data, x_2_data, x_3_data, x_4_data])

# Split the dataset into training and test sets (75% training, 25% test)
X_train, X_test, y_train, y_test = train_test_split(inputs, z_data, test_size=0.25, random_state=42)

# Build XNet: a single ComplexLinear block that maps the 4 input features
# through 5000 hidden units and averages them into one output per sample.
# NOTE: a trailing tf.keras.layers.Dense(1) could be appended if a learned
# output head is wanted.
input_dim = 4
hidden_units = 5000
model = tf.keras.Sequential(
    layers=[ComplexLinear(input_dim=input_dim, output_dim=hidden_units)],
)

# Piecewise-constant learning-rate schedule, dropping by 10x at epochs 300, 600 and 800.
def lr_schedule(epoch):
    """Return the learning rate for the given epoch index (0-based, as Keras passes it).

    The rate is 1e-3 before epoch 300, 1e-4 before 600, 1e-5 before 800,
    and 1e-6 from epoch 800 onward.
    """
    # (exclusive upper epoch bound, learning rate), scanned in ascending order.
    for upper_bound, rate in ((300, 0.001), (600, 0.0001), (800, 0.00001)):
        if epoch < upper_bound:
            return rate
    return 0.000001
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule=lr_schedule)

# Compile with mean-squared-error loss and Adam. When lr_scheduler is passed to fit, it sets the learning rate each epoch.
model.compile(loss='mean_squared_error', optimizer=tf.keras.optimizers.Adam())

# Custom callback to record training and test loss at each epoch
class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that records train and test loss after every epoch.

    Args:
        X_test: Held-out inputs evaluated at the end of each epoch.
        y_test: Targets matching ``X_test``.

    After ``fit`` returns:
    - ``losses_train`` holds ``logs['loss']`` reported by Keras for each epoch.
    - ``losses_test`` holds the test-set loss, computed with the weights at
      the end of that epoch.
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.losses_train = []
        self.losses_test = []
        self.X_test = X_test
        self.y_test = y_test

    def on_epoch_end(self, epoch, logs=None):
        train_loss = logs['loss']
        self.losses_train.append(train_loss)

        # return_dict=True yields a {name: value} mapping, so the loss is
        # extracted correctly even if metrics are later added to compile().
        # Without it, evaluate() returns a list in that case, which breaks the
        # '.4e' format below and any numeric post-processing of losses_test.
        test_results = self.model.evaluate(self.X_test, self.y_test, verbose=0, return_dict=True)
        test_loss = test_results['loss']
        self.losses_test.append(test_loss)

        # `epoch` is a 0-based index; report progress every 100 epochs.
        if (epoch + 1) % 100 == 0:
            print(f"Epoch {epoch + 1}, Loss (Train): {train_loss:.4e}, Loss (Test): {test_loss:.4e}")

# Train the model, recording per-epoch train/test loss through `callback`.
# perf_counter is a monotonic clock meant for measuring intervals; time.time()
# can jump if the system clock is adjusted. The measured time also includes
# the test-set evaluation that CustomCallback runs at every epoch end.
callback = CustomCallback(X_test, y_test)
start_time = time.perf_counter()
model.fit(X_train, y_train, epochs=1000, batch_size=32, verbose=0, callbacks=[callback, lr_scheduler])
elapsed_time = time.perf_counter() - start_time

# Predict on the test set
predictions = model.predict(X_test)

# Compute evaluation metrics from a single residual vector
Z_real = y_test
Z_pred = predictions[:, 0]  # the model emits a (-1, 1) column; flatten to match y_test
Z_diff = Z_real - Z_pred

mse = np.mean(Z_diff ** 2)
rmse = np.sqrt(mse)
mae = np.mean(np.abs(Z_diff))
print(f"Mean Squared Error (MSE): {mse:.4e}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4e}")
print(f"Mean Absolute Error (MAE): {mae:.4e}")

# ---------------------------------------------------------
# Plot smoothed loss curves for training and test sets
def smooth_curve(values, window_size=10):
    """Return the moving average of ``values`` over ``window_size`` points.

    Only complete windows are averaged, so the result has
    ``len(values) - window_size + 1`` entries. Element ``i`` is the mean of
    ``values[i:i + window_size]``.

    Args:
        values: 1-D sequence of numbers.
        window_size: Number of consecutive points per average; must be >= 1.

    Returns:
        1-D float NumPy array. It is empty when ``values`` is shorter than
        ``window_size``, because no complete window fits.

    Raises:
        ValueError: If ``window_size`` is less than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    values = np.asarray(values, dtype=float)
    # np.convolve(mode='valid') swaps its operands when the kernel is longer
    # than the data, which would return meaningless averages. It also raises
    # on empty input. Report "no complete window" explicitly instead.
    if values.size < window_size:
        return np.empty(0)
    return np.convolve(values, np.ones(window_size) / window_size, mode='valid')

# Smooth sqrt(per-epoch MSE) over `smoothing_window` epochs so the log-scale
# curves are readable.
smoothing_window = 10
smooth_train_loss = smooth_curve(np.sqrt(callback.losses_train), window_size=smoothing_window)
smooth_test_loss = smooth_curve(np.sqrt(callback.losses_test), window_size=smoothing_window)

# Each smoothed point averages `smoothing_window` consecutive epochs. Plot it
# at the (1-based) number of the last epoch in its window so the x-axis shows
# real epoch numbers rather than array indices.
train_epochs = np.arange(smoothing_window, smoothing_window + len(smooth_train_loss))
test_epochs = np.arange(smoothing_window, smoothing_window + len(smooth_test_loss))

plt.figure(figsize=(10, 6))
plt.plot(train_epochs, smooth_train_loss, label='Train', alpha=0.7)
plt.plot(test_epochs, smooth_test_loss, label='Test', alpha=0.7)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('RMSE', fontsize=14)
plt.yscale('log')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.title('Training and Test Loss over Epochs (Smoothed)')
plt.show()

print(f"Training time: {elapsed_time:.4f} seconds")

# Evaluate the model on a fresh set of random samples.
# They are stored under their own names so the training-data globals
# (x_*_data, z_data, inputs) keep describing the dataset that
# X_train/X_test were split from, instead of being silently overwritten.
num_test = 1000
eval_inputs = np.random.uniform(-1.0, 1.0, size=(num_test, input_dim))
z_eval = target_function(eval_inputs[:, 0], eval_inputs[:, 1], eval_inputs[:, 2], eval_inputs[:, 3])
z_pred = model.predict(eval_inputs)[:, 0]

eval_diff = z_eval - z_pred
mse = np.mean(eval_diff ** 2)
rmse = np.sqrt(mse)
mae = np.mean(np.abs(eval_diff))
print(f"[Final Evaluation] Mean Squared Error (MSE): {mse:.4e}")
print(f"[Final Evaluation] Root Mean Squared Error (RMSE): {rmse:.4e}")
print(f"[Final Evaluation] Mean Absolute Error (MAE): {mae:.4e}")