# -*- coding: utf-8 -*-
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Render every figure at 600 dpi.
plt.rcParams.update({'figure.dpi': 600})

# Seed NumPy's and TensorFlow's global random generators with the same value.
seed = 3
np.random.seed(seed)
tf.random.set_seed(seed)

class ComplexLinear(tf.keras.Model):
    """Affine projection followed by an element-wise rational response.

    With ``u = x @ real_weights + real_bias`` and ``d = imag_bias`` the
    layer returns::

        lambda_k1 * u / (u**2 + d**2) + lambda_k2 * d / (u**2 + d**2)

    All five parameter tensors are trainable and drawn from normal
    distributions at construction time.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Number of output units.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Creation order is kept fixed so a seeded run draws the same values.
        self.real_weights = tf.Variable(
            tf.random.normal([input_dim, output_dim], stddev=0.1), trainable=True)
        self.real_bias = tf.Variable(
            tf.random.normal([output_dim], mean=0.01, stddev=0.01), trainable=True)
        self.imag_bias = tf.Variable(
            tf.random.normal([output_dim], mean=-0.01, stddev=0.1), trainable=True)
        self.lambda_k1 = tf.Variable(
            tf.random.normal([output_dim], mean=0.01, stddev=0.0001), trainable=True)
        self.lambda_k2 = tf.Variable(
            tf.random.normal([output_dim], mean=0.01, stddev=0.0001), trainable=True)

    def call(self, x):
        """Apply the layer to ``x`` and return a tensor with ``output_dim`` columns."""
        # Only the real part of a complex input contributes to the output;
        # the imaginary part is discarded.
        if tf.dtypes.as_dtype(x.dtype).is_complex:
            x = tf.math.real(x)

        projected = tf.linalg.matmul(x, self.real_weights) + self.real_bias
        # Shared denominator u**2 + d**2 of both rational terms.
        denominator = tf.square(projected) + self.imag_bias**2

        output_1 = (projected / denominator) * self.lambda_k1
        output_2 = (self.imag_bias / denominator) * self.lambda_k2
        return output_1 + output_2

class Linear(tf.keras.layers.Layer):
    """Parameter-free layer that averages the features of each sample.

    Args:
        input_dim: Number of features per sample that ``call`` averages over.
        output_dim: Accepted but not used by this layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim

    def call(self, inputs):
        """Return the mean over ``input_dim`` features, one value per sample."""
        # Reshape to (samples, input_dim, 1) so that reducing axis 1 leaves
        # a (samples, 1) result.
        per_sample = tf.reshape(inputs, (-1, self.input_dim, 1))
        return tf.reduce_mean(per_sample, axis=1)

# Read the series from the first CSV column; commas are stripped before the
# values are cast to float32 and shaped into a single-column array.
df_raw = pd.read_csv('.csv')
first_column = df_raw.iloc[:, 0].replace(',', '', regex=True)
data = first_column.astype('float32').values.reshape(-1, 1)
df = pd.DataFrame(data, columns=['Signal'])

# Plot the raw series before any scaling.
plt.figure(figsize=(10, 6))
plt.plot(df)
plt.xlabel('Time Steps')
plt.ylabel('Value')
plt.title('Time Series Data')
plt.show()

# Chronological split point: the first 80% of the rows are used for training.
train_size = int(len(df) * 0.8)

# Fit the scaler on the training rows only, so the min/max of the held-out
# tail does not leak into preprocessing. Test values may therefore fall
# slightly outside the [0, 1] range.
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(df.iloc[:train_size])
scaled_data = scaler.transform(df)

train, test = scaled_data[:train_size], scaled_data[train_size:]

def create_dataset(dataset, look_back=10):
    """Turn a single-column series into sliding-window samples.

    Each sample is ``look_back`` consecutive values from column 0 and its
    target is the value that immediately follows the window.

    Args:
        dataset: 2-D array indexed as ``dataset[row, 0]``.
        look_back: Number of consecutive values in each input window.

    Returns:
        A tuple ``(X, Y)`` where ``X`` has shape ``(n, look_back)`` and ``Y``
        has shape ``(n,)``, with ``n = len(dataset) - look_back``. When the
        series is too short for a single window, ``n`` is 0 and ``X`` still
        has ``look_back`` columns.
    """
    n_windows = len(dataset) - look_back
    windows = [dataset[i:i + look_back, 0] for i in range(n_windows)]
    targets = [dataset[i + look_back, 0] for i in range(n_windows)]

    X = np.array(windows)
    if not windows:
        # np.array([]) is 1-D; keep the feature axis so callers always get a
        # 2-D array.
        X = X.reshape(0, look_back)
    return X, np.array(targets)

# Window both splits with the same history length.
look_back = 5
(X_train, y_train), (X_test, y_test) = (
    create_dataset(split, look_back) for split in (train, test)
)

# Layer widths: input window length, hidden units, output size.
num_hiddens = [look_back, 10, 1]
n_inputs, n_hidden, n_outputs = num_hiddens

# ComplexLinear produces the hidden units and Linear averages them.
model = Sequential([
    ComplexLinear(n_inputs, n_hidden),
    Linear(n_hidden, n_outputs),
])

model.compile(loss='mean_squared_error', optimizer='adam')

def lr_schedule(epoch):
    """Return the learning rate for ``epoch`` from a fixed step schedule.

    Epochs below 50 use 1e-3, below 200 use 5e-4, below 400 use 1e-5, and
    every later epoch uses 5e-6.
    """
    # (exclusive upper epoch bound, learning rate), in ascending order.
    steps = ((50, 1e-3), (200, 5e-4), (400, 1e-5))
    for upper_bound, rate in steps:
        if epoch < upper_bound:
            return rate
    return 5e-6

# Keras callback that looks up each epoch's learning rate via lr_schedule.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule=lr_schedule)

class CustomCallback(tf.keras.callbacks.Callback):
    """Record the training loss and a held-out loss after every epoch.

    Args:
        X_test: Inputs passed to ``model.evaluate`` at the end of each epoch.
        y_test: Targets passed to ``model.evaluate`` at the end of each epoch.

    Attributes are filled in as training progresses: ``losses_train`` holds
    the ``'loss'`` entry of each epoch's logs and ``losses_test`` the value
    returned by ``model.evaluate`` on the held-out data.
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.losses_train = []
        self.losses_test = []
        self.X_test = X_test
        self.y_test = y_test

    def on_epoch_end(self, epoch, logs=None):
        """Store both losses and print a summary every 50th epoch."""
        # The signature allows logs=None; fall back to NaN instead of raising
        # so the two loss histories stay the same length.
        train_loss = (logs or {}).get('loss', float('nan'))
        self.losses_train.append(train_loss)

        test_loss = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        self.losses_test.append(test_loss)

        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.4e}, Test Loss: {test_loss:.4e}")

# Train the model and time the run.
callback = CustomCallback(X_test, y_test)
start_time = time.time()
model.fit(
    X_train,
    y_train,
    epochs=500,
    batch_size=16,
    verbose=0,
    # lr_scheduler must be listed here, otherwise the step schedule defined
    # in lr_schedule is never applied and the optimizer keeps its default
    # learning rate.
    callbacks=[lr_scheduler, callback],
)
elapsed_time = time.time() - start_time

# Predictions for both splits, still in the scaler's units.
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

# Map predictions and targets back to the original units. The scaler was
# fitted on one feature, so targets are passed as a (n_samples, 1) column.
# The result is transposed back to (1, n_samples) so y_train[0] / y_test[0]
# remain the flat target vectors.
train_predict = scaler.inverse_transform(train_predict)
test_predict = scaler.inverse_transform(test_predict)
y_train = scaler.inverse_transform(np.reshape(y_train, (-1, 1))).T
y_test = scaler.inverse_transform(np.reshape(y_test, (-1, 1))).T

# Error metrics in original units.
mse = mean_squared_error(y_test[0], test_predict)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test[0], test_predict)
mae_train = mean_absolute_error(y_train[0], train_predict)

print(f"Validation MSE: {mse:.4e}")
print(f"Validation RMSE: {rmse:.4e}")
print(f"Validation MAE: {mae:.4e}")
print(f"Training MAE: {mae_train:.4e}")
print(f"Training time: {elapsed_time:.4f} seconds")

# Overlay train and test predictions on the original series.
plt.figure(figsize=(10, 6))
plt.plot(data, label='Original Data', color='dodgerblue', linewidth=2)

# Train predictions start after the first look_back samples; everything else
# stays NaN so it is not drawn.
train_end = len(train_predict) + look_back
train_plot = np.empty_like(data)
train_plot[:] = np.nan
train_plot[look_back:train_end] = train_predict
plt.plot(train_plot, label='Train Prediction', color='coral', linestyle='--', linewidth=2)

# Test predictions are shifted by a further look_back samples past the end of
# the train predictions.
test_start = train_end + look_back
test_plot = np.empty_like(data)
test_plot[:] = np.nan
test_plot[test_start:] = test_predict
plt.plot(test_plot, label='Test Prediction', color='limegreen', linestyle='--', linewidth=2)

plt.title('XNet Prediction')
plt.xlabel('Time Steps')
plt.legend(loc='upper left')
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
for figure_path in ('XNet_Model2.png', 'XNet_Model2.pdf', 'XNet_noise_0.05.jpg'):
    plt.savefig(figure_path)
plt.show()

# Per-epoch training and test loss on a logarithmic y-axis.
plt.figure(figsize=(10, 6))
for loss_history, curve_label in (
    (callback.losses_train, 'Training Loss'),
    (callback.losses_test, 'Test Loss'),
):
    plt.plot(loss_history, label=curve_label, alpha=0.7)
plt.title('Loss Curve')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.yscale('log')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
for figure_path in ('XNet_Model2_loss.png', 'XNet_Model2_loss.pdf'):
    plt.savefig(figure_path)
plt.show()