# -*- coding: utf-8 -*-
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Global configuration: high-resolution figures and fixed random seeds.
plt.rcParams.update({'figure.dpi': 600})

seed = 3
# Seed NumPy's global RNG and TensorFlow's global seed so the random weight
# initialisation (tf.random.normal in the layers below) repeats across runs.
np.random.seed(seed)
tf.random.set_seed(seed)

class ComplexLinear(tf.keras.Model):
    """Dense projection followed by a learned rational activation.

    For each output unit j, with ``z_j = x @ real_weights[:, j] + real_bias[j]``
    and ``c_j = imag_bias[j]``::

        out_j = (lambda_k1[j] * z_j + lambda_k2[j] * c_j) / (z_j**2 + c_j**2)

    Despite the class name, every parameter and intermediate tensor is
    real-valued; ``imag_bias`` only plays the role of ``c_j`` above.  If a
    complex tensor is passed in, only its real part is used.

    Args:
        input_dim: number of input features (last axis of ``x``).
        output_dim: number of output units.
    """

    def __init__(self, input_dim, output_dim):
        super(ComplexLinear, self).__init__()
        self.real_weights = tf.Variable(tf.random.normal([input_dim, output_dim], stddev=0.1), trainable=True)
        self.real_bias = tf.Variable(tf.random.normal([output_dim], mean=0.01, stddev=0.01), trainable=True)
        self.imag_bias = tf.Variable(tf.random.normal([output_dim], mean=-0.01, stddev=0.1), trainable=True)
        # Per-unit scale factors for the two terms of the rational activation.
        self.lambda_k1 = tf.Variable(tf.random.normal([output_dim], mean=0.01, stddev=0.0001), trainable=True)
        self.lambda_k2 = tf.Variable(tf.random.normal([output_dim], mean=0.05, stddev=0.0001), trainable=True)

    def call(self, x):
        # Only the real part of a complex input contributes.
        x_real = tf.math.real(x) if x.dtype.is_complex else x
        real_1 = tf.linalg.matmul(x_real, self.real_weights) + self.real_bias
        # Shared denominator z**2 + c**2; it is zero only if both z and c are exactly 0.
        complex_output = tf.square(real_1) + self.imag_bias**2
        output_1 = (real_1 / complex_output) * self.lambda_k1
        output_2 = (self.imag_bias / complex_output) * self.lambda_k2
        return output_1 + output_2

class Linear(tf.keras.layers.Layer):
    """Parameter-free head that averages groups of ``input_dim`` values.

    Despite its name, this layer has no trainable weights: it reshapes the
    input to ``(-1, input_dim, 1)`` and takes the mean over the middle axis,
    so it always emits a single column.  ``output_dim`` is accepted for
    signature symmetry with ``ComplexLinear`` but is not used.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Size of each group of values that gets averaged; output_dim is ignored.
        self.input_dim = input_dim

    def call(self, inputs):
        # Group the input into rows of input_dim values with a trailing unit axis,
        # then collapse each group to its mean -> shape (rows, 1).
        grouped = tf.reshape(inputs, shape=(-1, self.input_dim, 1))
        averaged = tf.reduce_mean(grouped, axis=1)
        return averaged

# Load the raw series from the first CSV column. Commas are stripped
# (presumably thousands separators — confirm against the data file) and the
# values become a float32 column vector of shape (n, 1).
df_raw = pd.read_csv('model_13_0_mixed.csv')
data = df_raw.iloc[:, 0].replace(',', '', regex=True).astype('float32').values.reshape(-1, 1)
df = pd.DataFrame(data, columns=['Signal'])

plt.figure(figsize=(10, 6))
plt.plot(df)
plt.title('Time Series Data')
plt.xlabel('Time Steps')
plt.ylabel('Value')
plt.show()

# Chronological 80/20 split: no shuffling, so the test set is the tail of the series.
train_size = int(len(df) * 0.8)

# Fit the scaler on the training portion only, so the min/max of the test set
# do not leak into the training inputs. Test values may therefore fall
# slightly outside [0, 1], which is expected.
scaler = MinMaxScaler()
scaler.fit(df.iloc[:train_size])
scaled_data = scaler.transform(df)

train, test = scaled_data[:train_size], scaled_data[train_size:]

def create_dataset(dataset, look_back=10):
    """Build supervised (window, next value) pairs from column 0 of *dataset*.

    Args:
        dataset: 2-D array-like; only the first column is used.
        look_back: number of consecutive past values in each input window.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(n, look_back)`` and ``Y`` has shape
        ``(n,)``, with ``n = max(len(dataset) - look_back, 0)``.  Row ``i`` of
        ``X`` is ``dataset[i:i + look_back, 0]`` and ``Y[i]`` is
        ``dataset[i + look_back, 0]``.  The dtype follows the input column.

    Raises:
        ValueError: if ``look_back`` is negative.
    """
    if look_back < 0:
        raise ValueError(f"look_back must be non-negative, got {look_back}")
    series = np.asarray(dataset)[:, 0]
    n_samples = max(len(series) - look_back, 0)
    if n_samples == 0:
        # Keep the 2-D layout so callers always see (0, look_back), not (0,).
        return (np.empty((0, look_back), dtype=series.dtype),
                np.empty((0,), dtype=series.dtype))
    windows = np.lib.stride_tricks.sliding_window_view(series, look_back)
    # sliding_window_view returns a read-only view; copy so callers own the data.
    X = np.array(windows[:n_samples])
    Y = series[look_back:].copy()
    return X, Y

# Supervised windows: each sample is `look_back` consecutive scaled values and
# its target is the value that follows. Windows are built within each split,
# so the first `look_back` values of each split are never used as targets.
look_back = 5
X_train, y_train = create_dataset(train, look_back)
X_test, y_test = create_dataset(test, look_back)

# Layer widths: window size -> 50 hidden units -> 1 output.
num_hiddens = [look_back, 50, 1]
n_in, n_hidden, n_out = num_hiddens
model = Sequential([ComplexLinear(n_in, n_hidden), Linear(n_hidden, n_out)])

# Plain MSE regression with the Adam optimizer at its default settings.
model.compile(loss='mean_squared_error', optimizer='adam')

# (exclusive epoch upper bound, learning rate), checked in order: the first
# bound strictly greater than the epoch selects the rate.
_LR_STEPS = (
    (100, 1e-3),
    (200, 5e-4),
    (300, 2e-4),
    (350, 1e-4),
    (400, 5e-5),
    (450, 5e-6),
)
# Rate used once the epoch reaches the last bound.
_LR_FINAL = 1e-6


def lr_schedule(epoch):
    """Return the piecewise-constant learning rate for epoch index *epoch*.

    The rate steps down from 1e-3 (epochs below 100) to 1e-6 (epoch 450 onward).
    """
    for upper_bound, rate in _LR_STEPS:
        if epoch < upper_bound:
            return rate
    return _LR_FINAL

# Keras callback that applies lr_schedule; it only takes effect when included
# in the `callbacks` list passed to model.fit.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule=lr_schedule)

class CustomCallback(tf.keras.callbacks.Callback):
    """Record per-epoch training and held-out losses; print them every 50 epochs.

    Attributes:
        losses_train: the ``loss`` value Keras reports at each epoch end.
        losses_test: ``model.evaluate`` on the held-out data at each epoch end.
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.losses_train, self.losses_test = [], []
        # Held-out inputs/targets evaluated after every epoch.
        self.X_test, self.y_test = X_test, y_test

    def on_epoch_end(self, epoch, logs=None):
        train_loss = logs['loss']
        self.losses_train.append(train_loss)
        held_out_loss = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        self.losses_test.append(held_out_loss)
        epoch_number = epoch + 1  # Keras passes a 0-based epoch index
        if epoch_number % 50:
            return
        print(f"Epoch {epoch_number}, Train Loss: {train_loss:.4e}, Test Loss: {held_out_loss:.4e}")

# Train for 500 epochs while recording per-epoch losses. lr_scheduler must be
# in the callbacks list; otherwise lr_schedule is never consulted and the
# optimizer's learning rate never changes.
callback = CustomCallback(X_test, y_test)
# perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
start_time = time.perf_counter()
model.fit(X_train, y_train, epochs=500, batch_size=16, verbose=0,
          callbacks=[lr_scheduler, callback])
elapsed_time = time.perf_counter() - start_time

# Predict on both splits and map predictions back to the original units.
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

train_predict = scaler.inverse_transform(train_predict)
test_predict = scaler.inverse_transform(test_predict)

# The scaler was fitted on a single column, so 1-D targets must be passed as
# (n_samples, 1). Wrapping them as `[y]` would give shape (1, n_samples), i.e.
# n_samples "features", which does not match the fitted column count. The
# result is transposed back to (1, n_samples) so y_train[0] / y_test[0] below
# are the 1-D target vectors.
y_train = scaler.inverse_transform(y_train.reshape(-1, 1)).T
y_test = scaler.inverse_transform(y_test.reshape(-1, 1)).T

# The "Validation" metrics are computed on the held-out test split.
mse = mean_squared_error(y_test[0], test_predict)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test[0], test_predict)
mae_train = mean_absolute_error(y_train[0], train_predict)

print(f"Validation MSE: {mse:.4e}")
print(f"Validation RMSE: {rmse:.4e}")
print(f"Validation MAE: {mae:.4e}")
print(f"Training MAE: {mae_train:.4e}")
print(f"Training Time: {elapsed_time:.2f} seconds")

# Overlay train/test predictions on the original (unscaled) series.
plt.figure(figsize=(10, 6))
plt.plot(data, label='Original Data', color='dodgerblue', linewidth=2)

# Each prediction sits at the index of the value it forecasts; all other
# positions stay NaN so matplotlib leaves a gap there.
train_start = look_back
train_plot = np.full_like(data, np.nan)
train_plot[train_start:train_start + len(train_predict)] = train_predict
plt.plot(train_plot, label='Train Prediction', color='coral', linestyle='--', linewidth=2)

# len(train_predict) == train_size - look_back, so this offset equals
# train_size + look_back: the index of the first test-window target.
test_start = len(train_predict) + 2 * look_back
test_plot = np.full_like(data, np.nan)
test_plot[test_start:] = test_predict
plt.plot(test_plot, label='Test Prediction', color='limegreen', linestyle='--', linewidth=2)

plt.title('XNet Prediction (Noise Level: 0)')
plt.xlabel('Time Steps')
plt.legend(loc='upper left')
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
for out_path in ('XNet_Model2.png', 'XNet_Model2.pdf', 'XNet_noise_0.jpg'):
    plt.savefig(out_path)
plt.show()

# Per-epoch loss curves recorded by the training callback, on a log scale.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(callback.losses_train, label='Training Loss', alpha=0.7)
ax.plot(callback.losses_test, label='Test Loss', alpha=0.7)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_yscale('log')
ax.legend()
ax.grid(True, linestyle='--', alpha=0.5)
ax.set_title('Loss Curve')
fig.tight_layout()
for out_path in ('XNet_Model2_loss.png', 'XNet_Model2_loss.pdf'):
    fig.savefig(out_path)
plt.show()