import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import mean_squared_error, mean_absolute_error
# Set the default resolution (dots per inch) for every figure created below.
plt.rcParams['figure.dpi'] = 600

class ComplexLinear(tf.keras.Model):
    """One layer of rational basis functions followed by a mean readout.

    For an input batch ``x`` the layer forms ``u = x @ real_weights + real_bias``
    and evaluates, for every hidden unit,

        ``lambda_k1 * u / (u**2 + d**2) + lambda_k2 * d / (u**2 + d**2)``

    where ``d`` is ``imag_bias``. The per-unit values are averaged over the
    last axis and returned with shape ``(-1, 1)``.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Number of hidden units (basis functions).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Parameters are created with add_weight() instead of bare tf.Variable
        # attributes: Keras 3 does not automatically track tf.Variable
        # attributes, so they would be missing from trainable_variables and
        # never be updated by the optimizer. add_weight() is tracked by both
        # Keras 2 and Keras 3.
        self.real_weights = self._normal_weight(
            "real_weights", (input_dim, output_dim), mean=0.0)
        self.real_bias = self._normal_weight(
            "real_bias", (output_dim,), mean=0.01)
        self.imag_bias = self._normal_weight(
            "imag_bias", (output_dim,), mean=-0.1)
        self.lambda_k1 = self._normal_weight(
            "lambda_k1", (output_dim,), mean=0.5)
        self.lambda_k2 = self._normal_weight(
            "lambda_k2", (output_dim,), mean=1.0)

    def _normal_weight(self, name, shape, mean, stddev=0.1):
        """Create a trainable float32 weight drawn from N(mean, stddev**2)."""
        return self.add_weight(
            name=name,
            shape=shape,
            dtype="float32",
            initializer=tf.keras.initializers.RandomNormal(mean=mean, stddev=stddev),
            trainable=True,
        )

    def call(self, x):
        """Apply the layer to ``x`` and return a tensor of shape ``(-1, 1)``.

        If ``x`` has a complex dtype only its real part is used; the imaginary
        part does not enter the computation.
        """
        if tf.dtypes.as_dtype(x.dtype).is_complex:
            x_real = tf.math.real(x)
        else:
            x_real = x

        # Affine pre-activation u = x @ W + b.
        real_1 = tf.linalg.matmul(x_real, self.real_weights) + self.real_bias
        real = tf.square(real_1)
        # Shared denominator u**2 + d**2 of both rational terms.
        complex_output = real + self.imag_bias ** 2

        complex_output_1 = real_1 / complex_output
        output_1 = complex_output_1 * self.lambda_k1

        complex_output_2 = self.imag_bias / complex_output
        output_2 = complex_output_2 * self.lambda_k2

        # Average the hidden units into a single value per sample.
        x = output_1 + output_2
        x = tf.reduce_mean(x, -1)
        x = tf.reshape(x, [-1, 1])
        return x

def target_function(*x):
    """Evaluate ``exp(sum_i sin(pi * x_i / 2)**2 / 100)``.

    Each positional argument is one coordinate (a scalar or an array); the
    arguments are stacked along a new leading axis and the sum runs over
    that axis, so the result has the shape of a single argument.
    """
    coords = np.array(x)
    half_angle_sine = np.sin(np.pi * coords / 2)
    exponent = (half_angle_sine ** 2).sum(axis=0) / 100
    return np.exp(exponent)

def generate_data(batch_size=32, input_dim=100):
    """Draw a random batch and its ``target_function`` values.

    Args:
        batch_size: Number of samples to draw.
        input_dim: Number of coordinates per sample.

    Returns:
        ``(inputs, z_data)`` where ``inputs`` has shape
        ``(batch_size, input_dim)`` with entries sampled uniformly from
        ``[-1, 1)`` and ``z_data`` holds one target value per sample.
    """
    sample_shape = (batch_size, input_dim)
    inputs = np.random.uniform(-1.0, 1.0, sample_shape)

    # Rows of the transpose are the coordinate columns of `inputs`.
    z_data = target_function(*inputs.T)
    return inputs, z_data

class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that synthesizes a fresh random batch on every access.

    Args:
        batch_size: Number of samples per generated batch.
        input_dim: Number of coordinates per sample.
        steps_per_epoch: Number of batches reported by ``__len__``, i.e. how
            many batches Keras draws in one epoch.
    """

    def __init__(self, batch_size, input_dim=100, steps_per_epoch=250):
        # Keras 3 warns when a Sequence/PyDataset subclass skips the base
        # initializer; on Keras 2 this is a harmless no-op.
        super().__init__()
        self.batch_size = batch_size
        self.input_dim = input_dim
        self.steps_per_epoch = steps_per_epoch

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.steps_per_epoch

    def __getitem__(self, index):
        """Return a newly sampled ``(inputs, targets)`` batch.

        ``index`` is ignored: data is regenerated on every call, so the same
        index yields different batches.
        """
        inputs, targets = generate_data(self.batch_size, self.input_dim)
        return inputs, targets

class CustomCallback(tf.keras.callbacks.Callback):
    """Record train and test loss after every epoch.

    After each epoch the training loss reported by Keras is appended to
    ``losses_train`` and the model is evaluated on the supplied test set,
    with the result appended to ``losses_test``. A summary line is printed
    every 50th epoch.
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        # Per-epoch histories, filled in by on_epoch_end.
        self.losses_train = []
        self.losses_test = []

    def on_epoch_end(self, epoch, logs=None):
        """Store this epoch's losses and print a report on every 50th epoch."""
        self.losses_train.append(logs['loss'])

        test_loss = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        self.losses_test.append(test_loss)

        # `epoch` is zero-based; only epochs 50, 100, ... are reported.
        if (epoch + 1) % 50 != 0:
            return
        print(f"Epoch {epoch + 1}, Loss (Train): {logs['loss']:.4e}, Loss (Test): {test_loss:.4e}")

def lr_schedule(epoch):
    """Return the learning rate for the given (zero-based) epoch.

    The rate is piecewise constant: 1e-3 before epoch 100, 2e-4 before
    epoch 150, 5e-5 before epoch 200 and 1e-5 from epoch 200 onwards.
    """
    # (exclusive upper epoch bound, learning rate), in increasing order.
    stages = (
        (100, 0.001),
        (150, 0.0002),
        (200, 0.00005),
    )
    for upper_bound, rate in stages:
        if epoch < upper_bound:
            return rate
    return 0.00001

def train_model(batch_size=32, hidden_units=5000, input_dim=100, epochs=250):
    """Build, train and time a single-layer ``ComplexLinear`` regressor.

    The model is compiled with Adam and mean-squared-error loss, trained on
    batches produced by ``DataGenerator`` with the learning rate set per
    epoch by ``lr_schedule``, and evaluated after every epoch on a test set
    of 1000 samples drawn once before training.

    Args:
        batch_size: Number of samples in each generated training batch.
        hidden_units: Number of hidden units in the ``ComplexLinear`` layer.
        input_dim: Number of coordinates per input sample.
        epochs: Number of training epochs.

    Returns:
        ``(model, losses_train, losses_test)``: the trained model and the
        per-epoch train and test losses recorded by ``CustomCallback``.
    """
    model = tf.keras.Sequential([
        ComplexLinear(input_dim, hidden_units)
    ])

    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mean_squared_error')

    # Test set used by the callback after every epoch.
    X_test, y_test = generate_data(1000, input_dim)

    lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

    callback = CustomCallback(X_test=X_test, y_test=y_test)

    train_data_generator = DataGenerator(batch_size, input_dim)

    # perf_counter() is monotonic, so the measured duration is not distorted
    # by system clock adjustments the way time.time() can be.
    start_time = time.perf_counter()
    model.fit(train_data_generator, epochs=epochs, verbose=0, callbacks=[callback, lr_scheduler])
    elapsed_time = time.perf_counter() - start_time

    print(f"Training completed in {elapsed_time:.2f} seconds.")

    return model, callback.losses_train, callback.losses_test

# Train with the default configuration and keep the per-epoch loss histories.
model, train_losses, test_losses = train_model()

# Score the trained model on a freshly drawn evaluation set.
X_test, y_test = generate_data(1000, input_dim=100)
predictions = model.predict(X_test)

Z_real = y_test
Z_pred = predictions[:, 0]  # predictions come back with a trailing axis; take column 0
Z_diff = Z_real - Z_pred

# Error metrics computed from the residuals.
mse = np.mean(np.square(Z_diff))
rmse = np.sqrt(mse)
mae = np.mean(np.abs(Z_diff))

print(f"Mean Squared Error (MSE): {mse:.4e}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4e}")
print(f"Mean Absolute Error (MAE): {mae:.4e}")

# ---------------------------------------------------------
# Learning curves: square root of the recorded per-epoch losses, log-scaled.
plt.figure(figsize=(10, 6))
for loss_history, curve_label in ((train_losses, 'Train'), (test_losses, 'Test')):
    plt.plot(np.sqrt(loss_history), label=curve_label, alpha=0.7)

plt.title('XNet with 5000 basis functions')
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('RMSE', fontsize=14)
plt.yscale('log')
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=15)
plt.show()