# -*- coding: utf-8 -*-
"""
Created on Mon Aug  5 11:29:14 2024

@author: kernel
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Layer
from tensorflow.keras.initializers import glorot_uniform, orthogonal, zeros
from tensorflow.keras import backend as K
import tensorflow as tf

class ComplexLinear(tf.keras.Model):
    """Affine projection followed by a reciprocal-style "complex" nonlinearity.

    For an input ``x`` the block computes

        r = x_real @ real_weights + real_bias
        d = r**2 + imag_bias**2
        out = lambda_k1 * (r / d) + lambda_k2 * (imag_bias / d)

    element-wise over the output features. Mathematically, ``r / d`` and
    ``-imag_bias / d`` are the real and imaginary parts of
    ``1 / (r + 1j * imag_bias)``.

    If ``x`` has a complex dtype, only its real part is used; the imaginary
    part is ignored.

    Args:
        input_dim: Size of the last axis of ``x`` (rows of ``real_weights``).
        output_dim: Number of output features.

    NOTE(review): ``d`` is zero wherever ``r == 0`` and ``imag_bias == 0``,
    which yields inf/NaN; consider adding a small epsilon if training diverges.
    """

    def __init__(self, input_dim, output_dim):
        super(ComplexLinear, self).__init__()
        self.real_weights = tf.Variable(tf.random.normal([input_dim, output_dim], mean=0.1, stddev=0.2, dtype=tf.float32), trainable=True)
        self.real_bias = tf.Variable(tf.random.normal([output_dim], mean=0, stddev=0.01, dtype=tf.float32), trainable=True)
        self.imag_bias = tf.Variable(tf.random.normal([output_dim], mean=-0.01, stddev=0.01, dtype=tf.float32), trainable=True)
        # Per-feature mixing weights for the two reciprocal components.
        self.lambda_k1 = tf.Variable(tf.random.normal([output_dim], mean=0.5, stddev=0.01, dtype=tf.float32), trainable=True)
        self.lambda_k2 = tf.Variable(tf.random.normal([output_dim], mean=0.5, stddev=0.01, dtype=tf.float32), trainable=True)

    def call(self, x):
        """Apply the projection and reciprocal mixing to ``x``."""
        # Complex inputs contribute only their real part; real inputs pass through.
        if tf.dtypes.as_dtype(x.dtype).is_complex:
            x_real = tf.math.real(x)
        else:
            x_real = x

        real_1 = tf.linalg.matmul(x_real, self.real_weights) + self.real_bias
        # Shared denominator |r + i*imag_bias|^2.
        denom = tf.square(real_1) + self.imag_bias ** 2

        output_1 = (real_1 / denom) * self.lambda_k1
        output_2 = (self.imag_bias / denom) * self.lambda_k2
        return output_1 + output_2



class CustomLSTMCell(Layer):
    """LSTM-style recurrent cell whose input projection is a ``ComplexLinear``.

    Per step, the gate pre-activations are
    ``complex_linear(inputs) + h_tm1 @ recurrent_kernel + bias``, split into
    four equal chunks (input gate, forget gate, candidate, output gate).

    Args:
        units: Size of the hidden state ``h`` and the cell state ``c``.
        input_dim: Number of input features per step; sizes the weights of the
            ``ComplexLinear`` input projection.
    """

    def __init__(self, units, input_dim, **kwargs):
        self.units = units
        # Two states are carried between steps: hidden state h and cell state c.
        self.state_size = [units, units]
        super(CustomLSTMCell, self).__init__(**kwargs)

        # Add the ComplexLinear layer: it maps the inputs to the 4 gate
        # pre-activations, hence units * 4 outputs.
        self.complex_linear = ComplexLinear(input_dim, units * 4)

    def build(self, input_shape):
        # No separate input kernel is created: the input projection is
        # ``self.complex_linear``, whose weights already exist. An unused
        # kernel would only add trainable parameters that never get a gradient.
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4),
                                                initializer=orthogonal(),
                                                name='recurrent_kernel')
        self.bias = self.add_weight(shape=(self.units * 4,),
                                    initializer=zeros(),
                                    name='bias')
        self.built = True

    def call(self, inputs, states):
        """Advance one step; return ``(h, [h, c])``."""
        h_tm1, c_tm1 = states  # previous hidden state and cell state

        z = self.complex_linear(inputs) + K.dot(h_tm1, self.recurrent_kernel) + self.bias

        z0, z1, z2, z3 = tf.split(z, num_or_size_splits=4, axis=1)

        i = K.sigmoid(z0)  # input gate
        f = K.sigmoid(z1)  # forget gate
        # NOTE(review): a standard LSTM applies tanh to the candidate (z2) and to
        # c before the output gate; this cell uses both linearly. Confirm that
        # this is intentional.
        c = f * c_tm1 + i * z2
        o = K.sigmoid(z3)  # output gate
        h = o * c

        return h, [h, c]

class CustomLSTM(Layer):
    """Recurrent layer that unrolls a ``CustomLSTMCell`` over the time axis.

    Args:
        units: Hidden-state size passed to the cell.
        return_sequences: If True, return the hidden state of every timestep
            stacked along axis 1; otherwise return only the last one.
        return_state: If True, also return the final ``[h, c]`` states.
        **kwargs: Forwarded to ``Layer``. If ``input_shape`` (or
            ``batch_input_shape``) is given, its last entry sets the number of
            input features seen by the cell; otherwise that defaults to 1.
    """

    def __init__(self, units, return_sequences=False, return_state=False, **kwargs):
        self.units = units
        self.return_sequences = return_sequences
        self.return_state = return_state
        # The cell's ComplexLinear weights are sized at construction time, so
        # the per-step feature count is taken from the declared shape (when
        # given) instead of being hard-coded.
        declared_shape = kwargs.get('input_shape') or kwargs.get('batch_input_shape')
        input_dim = 1
        if declared_shape and declared_shape[-1] is not None:
            input_dim = int(declared_shape[-1])
        super(CustomLSTM, self).__init__(**kwargs)
        self.cell = CustomLSTMCell(units, input_dim=input_dim)

    def build(self, input_shape):
        # Delegate weight creation to the cell; presumably input_shape is
        # (batch, timesteps, features), matching inputs[:, t, :] in call.
        self.cell.build(input_shape)
        self.built = True

    def call(self, inputs, initial_state=None):
        """Step the cell over every timestep of ``inputs``.

        Returns the last hidden state (or all hidden states stacked on axis 1
        when ``return_sequences``), followed by the final ``[h, c]`` when
        ``return_state``.

        Raises:
            ValueError: if the timestep dimension is unknown or zero.
        """
        if initial_state is None:
            initial_state = self.get_initial_state(inputs)

        timesteps = inputs.shape[1]
        if not timesteps:
            # Without this, range(None) raises a TypeError and an empty loop
            # leaves ``output`` unbound.
            raise ValueError(
                'CustomLSTM requires a known, non-zero number of timesteps; '
                f'got inputs with shape {inputs.shape}.')

        states = initial_state
        outputs = []

        for t in range(timesteps):
            output, states = self.cell(inputs[:, t, :], states)
            if self.return_sequences:
                outputs.append(output)

        if self.return_sequences:
            outputs = tf.stack(outputs, axis=1)
        else:
            outputs = output

        if self.return_state:
            return [outputs] + states
        return outputs

    def get_initial_state(self, inputs):
        """Return zero ``[h, c]`` states of shape (batch, units)."""
        batch_size = tf.shape(inputs)[0]
        return [tf.zeros((batch_size, self.units)) for _ in range(2)]


# Load the series. The raw string keeps the backslash literal instead of
# relying on the invalid "\m" escape (a SyntaxWarning on Python 3.12+).
# NOTE(review): "..." is a placeholder path; point it at the real file.
df = pd.read_csv(r"...\model_13_0.05.csv")
data = df.values.astype('float32')

# Chronological 80/20 split. The scaler is fitted on the training rows only,
# so min/max statistics from the test period do not leak into training.
# Test values may therefore fall slightly outside [0, 1].
train_size = int(len(data) * 0.8)
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(data[:train_size])
scaled_data = scaler.transform(data)

train, test = scaled_data[0:train_size], scaled_data[train_size:len(scaled_data)]

def create_dataset(dataset, look_back=1):
    """Build supervised (window, next value) pairs from column 0 of ``dataset``.

    Sample ``i`` is ``X[i] = dataset[i:i + look_back, 0]``; its target is
    ``Y[i] = dataset[i + look_back, 0]``.

    Args:
        dataset: 2-D array-like; only column 0 is used.
        look_back: Window length, must be >= 1.

    Returns:
        ``(X, Y)`` with shapes ``(n, look_back)`` and ``(n,)``, where
        ``n = max(len(dataset) - look_back, 0)``, in ``dataset``'s dtype.
        ``X`` stays 2-D even when ``n == 0`` so callers can still reshape it.

    Raises:
        ValueError: if ``look_back < 1``.
    """
    if look_back < 1:
        raise ValueError(f"look_back must be >= 1, got {look_back}")

    dataset = np.asarray(dataset)
    n_windows = max(len(dataset) - look_back, 0)
    X = np.empty((n_windows, look_back), dtype=dataset.dtype)
    Y = np.empty((n_windows,), dtype=dataset.dtype)
    for i in range(n_windows):
        X[i] = dataset[i:i + look_back, 0]
        Y[i] = dataset[i + look_back, 0]
    return X, Y

# Window length: each sample is `look_back` consecutive scaled values and the
# target is the value that follows them.
look_back = 5
X_train, y_train = create_dataset(train, look_back)
X_test, y_test = create_dataset(test, look_back)

# Append a trailing feature axis, (samples, timesteps) -> (samples, timesteps, 1),
# giving the 3-D layout the recurrent layer slices as inputs[:, t, :].
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

# One 10-unit recurrent layer feeding a single-output regression head,
# trained with Adam on mean squared error.
model = Sequential([
    CustomLSTM(10, input_shape=(look_back, 1)),
    Dense(1),
])
model.compile(optimizer='adam', loss='mean_squared_error')

class CustomCallback(tf.keras.callbacks.Callback):
    """Record the training loss after every epoch and print it periodically.

    Args:
        print_every: Print the loss every ``print_every`` epochs (default 100,
            as before). Values < 1 disable printing.

    Attributes:
        losses: One entry per completed epoch; NaN when no loss was reported.
    """

    def __init__(self, print_every=100):
        super().__init__()
        self.print_every = print_every
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        # ``logs`` may be None per the signature. Record NaN rather than
        # crashing, so ``losses[k]`` still corresponds to epoch k.
        loss = (logs or {}).get('loss', float('nan'))
        self.losses.append(loss)
        # ``epoch`` is 0-based, so epoch + 1 is the number of completed epochs.
        if self.print_every > 0 and (epoch + 1) % self.print_every == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss:.4e}")


# Train without Keras progress output; the callback collects the per-epoch
# loss (and reports it periodically) instead.
callback = CustomCallback()
model.fit(
    x=X_train,
    y=y_train,
    epochs=500,
    batch_size=32,
    verbose=0,
    callbacks=[callback],
)


# Predictions are in the scaled space; the Dense(1) head gives shape (n, 1).
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

# Map predictions and targets back to the original units. MinMaxScaler expects
# (n_samples, n_features) input, so the 1-D targets are passed as a column
# (not as a single row that only worked via broadcasting), then flattened.
train_predict = scaler.inverse_transform(train_predict)
test_predict = scaler.inverse_transform(test_predict)
y_train = scaler.inverse_transform(np.asarray(y_train).reshape(-1, 1)).ravel()
y_test = scaler.inverse_transform(np.asarray(y_test).reshape(-1, 1)).ravel()

# Error metrics on the held-out split, plus training MAE for comparison.
mse = mean_squared_error(y_test, test_predict)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, test_predict)
mae_train = mean_absolute_error(y_train, train_predict)

print(f"Mean Squared Error (MSE) of validation data: {mse:.4e}")
print(f"Root Mean Squared Error (RMSE) of validation data: {rmse:.4e}")
print(f"LSTM Mean Absolute Error (MAE) of validation data: {mae:.4e}")
print(f"LSTM Mean Absolute Error (MAE) of training data: {mae_train:.4e}")


# Overlay train/test predictions on the original series. Each prediction sits
# at the row index of the value it forecasts; every other position is NaN, so
# matplotlib leaves a gap there.
plt.figure(figsize=(12, 6))
plt.plot(df, label='Original data')

train_plot = np.full_like(data, np.nan)
train_start = look_back  # the first training target follows the first window
train_plot[train_start:train_start + len(train_predict), :] = train_predict
plt.plot(train_plot, label='Train prediction')

# The first test target lies look_back rows after the start of the test
# split, i.e. at len(train_predict) + 2 * look_back; the predictions run to the end.
test_plot = np.full_like(data, np.nan)
test_start = len(train_predict) + look_back * 2
test_plot[test_start:len(data), :] = test_predict
plt.plot(test_plot, label='Test prediction')

plt.ylabel('X-LSTM', fontsize=14)
plt.legend()
plt.show()



# Plot the per-epoch training loss collected by the callback, then persist it.
loss_history = callback.losses

plt.figure(figsize=(10, 6))
plt.plot(loss_history, label='Training Loss')
plt.title('Loss During Training')
plt.xlabel('Epoch')
plt.ylabel('Loss of XLSTM')
plt.legend()
plt.show()

np.save('XLSTM_loss.npy', np.asarray(loss_history))