import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Layer
from tensorflow.keras.initializers import glorot_uniform, orthogonal, zeros
from tensorflow.keras import backend as K
import tensorflow as tf
plt.rcParams['figure.dpi'] = 600
class CustomLSTMCell(Layer):
    """Single time step of an LSTM with fused gate weights.

    The four gate pre-activations are computed with one fused matmul on the
    input plus one on the previous hidden state, then split in the order
    input, forget, candidate, output (no peephole connections).

    Args:
        units: Size of both the hidden state ``h`` and the cell state ``c``.
        **kwargs: Forwarded to ``Layer``.
    """

    def __init__(self, units, **kwargs):
        # Initialise the base Layer first so its attribute tracking exists
        # before any attribute is assigned.
        super().__init__(**kwargs)
        self.units = units
        self.state_size = [units, units]  # sizes of [h, c]

    def build(self, input_shape):
        """Create the fused input kernel, recurrent kernel and bias.

        Args:
            input_shape: Shape of the input; only the last dimension
                (feature count) is used.
        """
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      initializer=glorot_uniform(),
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4),
                                                initializer=orthogonal(),
                                                name='recurrent_kernel')
        self.bias = self.add_weight(shape=(self.units * 4,),
                                    initializer=zeros(),
                                    name='bias')
        self.built = True

    def call(self, inputs, states):
        """Advance the cell by one step.

        Args:
            inputs: Input for this step, shape ``(batch, features)``.
            states: ``[h_tm1, c_tm1]`` from the previous step.

        Returns:
            ``(h, [h, c])`` — the step output and the new states.
        """
        h_tm1, c_tm1 = states  # previous hidden state and cell (carry) state

        # Core tf ops throughout (the split below already uses tf), rather
        # than mixing in keras.backend helpers.
        z = (tf.matmul(inputs, self.kernel)
             + tf.matmul(h_tm1, self.recurrent_kernel)
             + self.bias)

        z_i, z_f, z_c, z_o = tf.split(z, num_or_size_splits=4, axis=1)

        i = tf.sigmoid(z_i)                     # input gate
        f = tf.sigmoid(z_f)                     # forget gate
        c = f * c_tm1 + i * tf.tanh(z_c)        # new cell state
        o = tf.sigmoid(z_o)                     # output gate
        h = o * tf.tanh(c)                      # new hidden state / output

        return h, [h, c]

    def get_config(self):
        """Return the constructor config so the layer can be re-created."""
        config = super().get_config()
        config.update({'units': self.units})
        return config

class CustomLSTM(Layer):
    """LSTM layer that unrolls a ``CustomLSTMCell`` over the time axis in Python.

    Args:
        units: Size of the hidden and cell states.
        return_sequences: If True, return every step's output stacked on
            axis 1; otherwise return only the last step's output.
        return_state: If True, return ``[outputs, h, c]`` with the final
            states appended.
        **kwargs: Forwarded to ``Layer``.
    """

    def __init__(self, units, return_sequences=False, return_state=False, **kwargs):
        # Initialise the base Layer before assigning attributes/sub-layers.
        super().__init__(**kwargs)
        self.units = units
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.cell = CustomLSTMCell(units)

    def build(self, input_shape):
        """Build the wrapped cell; it only reads the last (feature) dimension."""
        self.cell.build(input_shape)
        self.built = True

    def call(self, inputs, initial_state=None):
        """Run the cell over each time step of ``inputs``.

        Args:
            inputs: Sequences sliced as ``inputs[:, t, :]``; axis 1 is time and
                must have a static, non-zero length because the loop is
                unrolled in Python.
            initial_state: Optional ``[h0, c0]``; zero states when omitted.

        Returns:
            The last step's output, or all outputs stacked on axis 1 when
            ``return_sequences`` is True; followed by the final ``h`` and
            ``c`` in a list when ``return_state`` is True.

        Raises:
            ValueError: If the number of time steps is unknown or zero.
        """
        timesteps = inputs.shape[1]
        if not timesteps:
            # range(None) would raise an opaque TypeError, and zero steps
            # would leave `output` unbound; fail early with a clear message.
            raise ValueError(
                f"{type(self).__name__} needs a static, non-zero number of "
                f"time steps; got inputs with shape {inputs.shape}."
            )

        states = initial_state if initial_state is not None else self.get_initial_state(inputs)
        step_outputs = []

        for t in range(timesteps):
            output, states = self.cell(inputs[:, t, :], states)
            if self.return_sequences:
                step_outputs.append(output)

        outputs = tf.stack(step_outputs, axis=1) if self.return_sequences else output

        if self.return_state:
            return [outputs] + list(states)
        return outputs

    def get_initial_state(self, inputs):
        """Return zero ``[h, c]`` states sized to the batch of ``inputs``."""
        batch_size = tf.shape(inputs)[0]
        # Use the input dtype rather than tf.zeros' float32 default so the
        # state and input tensors always agree.
        return [tf.zeros((batch_size, self.units), dtype=inputs.dtype) for _ in range(2)]

    def get_config(self):
        """Return the constructor config so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'return_sequences': self.return_sequences,
            'return_state': self.return_state,
        })
        return config

# Load the series from the first CSV column. Commas are stripped before the
# float32 cast (presumably thousands separators). The raw string keeps the
# Windows-style backslashes literal instead of relying on invalid escapes.
df = pd.read_csv(r'...\data_generator\model_13_0.05.csv')
data = df.iloc[:, 0].replace(',', '', regex=True).astype('float32').values
data = data.reshape(-1, 1)  # (n_samples, 1): the 2-D layout MinMaxScaler expects
df = data  # NOTE(review): rebinds df to the NumPy array; kept for compatibility

# Chronological 80/20 split (no shuffling).
train_size = int(len(data) * 0.8)

# Fit the scaler on the training portion only, so the test split's min/max
# do not leak into training. Test values may fall slightly outside [0, 1].
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(data[:train_size])
scaled_data = scaler.transform(data)
train, test = scaled_data[:train_size], scaled_data[train_size:]

def create_dataset(dataset, look_back=10):
    """Build supervised sliding-window samples from a series.

    Sample ``i`` uses values ``i .. i + look_back - 1`` as input and value
    ``i + look_back`` as target. Only column 0 of a 2-D ``dataset`` is used;
    a 1-D ``dataset`` is used as-is.

    Args:
        dataset: Array-like series, 1-D or with the series in column 0.
        look_back: Number of past values per input window (must be >= 0).

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(n_samples, look_back)`` and ``Y``
        has shape ``(n_samples,)``, with ``n_samples = max(len - look_back, 0)``.
        Both keep the dtype of ``dataset``. A too-short series yields empty
        arrays that still have those ranks, so downstream reshapes work.

    Raises:
        ValueError: If ``look_back`` is negative.
    """
    if look_back < 0:
        raise ValueError(f"look_back must be non-negative, got {look_back}")

    arr = np.asarray(dataset)
    series = arr if arr.ndim == 1 else arr[:, 0]

    n_samples = max(len(series) - look_back, 0)
    # Row i of `window_idx` holds the indices i, i+1, ..., i+look_back-1.
    window_idx = np.arange(n_samples)[:, None] + np.arange(look_back)[None, :]
    X = series[window_idx]
    # Copy so Y never aliases the caller's array (fancy indexing already copies X).
    Y = series[look_back:look_back + n_samples].copy()
    return X, Y

look_back = 5  # past values per input window

# Turn each split into (window, next value) pairs, then append a trailing
# feature axis: (samples, look_back) -> (samples, look_back, 1).
X_train, y_train = create_dataset(train, look_back)
X_test, y_test = create_dataset(test, look_back)
X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))
X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))

# A 10-unit CustomLSTM feeding a single-output Dense regression head,
# trained with Adam on mean squared error.
model = Sequential([
    CustomLSTM(10, input_shape=(look_back, 1)),
    Dense(1),
])
model.compile(optimizer='adam', loss='mean_squared_error')

class CustomCallback(tf.keras.callbacks.Callback):
    """Record per-epoch training loss and held-out MSE during ``fit``.

    At the end of each epoch:
    - Keras' ``logs['loss']`` is appended to both ``losses`` and ``train_losses``.
    - The model predicts on ``X_test``, and scikit-learn's MSE against
      ``y_test`` is appended to ``test_losses``.

    Args:
        X_test: Inputs passed to ``model.predict`` every epoch.
        y_test: Targets compared with those predictions (same units as the
            model's training targets, as supplied by the caller).
        print_every: Print a progress line every this many epochs (default 100).
            A value <= 0 disables printing.
    """

    def __init__(self, X_test, y_test, print_every=100):
        super().__init__()
        self.losses = []        # training loss per epoch
        self.train_losses = []  # same values as `losses`; previously left empty
        self.test_losses = []   # held-out MSE per epoch
        self.X_test = X_test
        self.y_test = y_test
        self.print_every = print_every

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's train/test losses and optionally print them."""
        # `logs` defaults to None; tolerate that instead of raising TypeError.
        logs = logs or {}
        train_loss = logs.get('loss', float('nan'))
        self.losses.append(train_loss)
        self.train_losses.append(train_loss)

        test_pred = self.model.predict(self.X_test, verbose=0)
        test_loss = mean_squared_error(self.y_test, test_pred)
        self.test_losses.append(test_loss)

        if self.print_every > 0 and (epoch + 1) % self.print_every == 0:
            print(f"Epoch {epoch + 1}, Training Loss: {train_loss:.4e}, Test Loss: {test_loss:.4e}")


# The callback records per-epoch train/test losses for the plots below.
callback = CustomCallback(X_test, y_test)

# Wall-clock training time. perf_counter is monotonic and high-resolution,
# so it is unaffected by system clock adjustments (unlike time.time()).
start_time = time.perf_counter()
model.fit(X_train, y_train, epochs=500, batch_size=32, verbose=0, callbacks=[callback])
elapsed_time = time.perf_counter() - start_time



# Predictions come out in the scaled space of the training targets.
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

# Map predictions and targets back to the original units. The scaler was
# fitted on a single column, so reshape the 1-D targets to (n_samples, 1)
# instead of passing a (1, n_samples) row and relying on broadcasting.
train_predict = scaler.inverse_transform(train_predict)
test_predict = scaler.inverse_transform(test_predict)
y_train = scaler.inverse_transform(np.reshape(y_train, (-1, 1)))
y_test = scaler.inverse_transform(np.reshape(y_test, (-1, 1)))

# Compute error metrics on the unscaled values.
mse = mean_squared_error(y_test, test_predict)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, test_predict)
mae_train = mean_absolute_error(y_train, train_predict)

print(f"Mean Squared Error (MSE) of validation data: {mse:.4e}")
print(f"Root Mean Squared Error (RMSE) of validation data: {rmse:.4e}")
print(f"LSTM Mean Absolute Error (MAE) of validation data: {mae:.4e}")
print(f"LSTM Mean Absolute Error (MAE) of training data: {mae_train:.4e}")
print(f"Training time: {elapsed_time:.4f}")


# Overlay train/test predictions on the original series. Each prediction sits
# at the index of the value it forecasts; all other positions are NaN, so
# matplotlib leaves gaps there.
plt.figure(figsize=(10, 6))
plt.plot(data, label='Original Data', color='dodgerblue', linewidth=2)

# The first training target is at index look_back.
train_plot = np.full_like(data, np.nan)
train_plot[look_back:look_back + len(train_predict)] = train_predict
plt.plot(train_plot, label='Train Prediction', color='coral', linestyle='--', linewidth=2)

# The first test target is look_back steps past the split point
# (len(train_predict) + 2 * look_back == train_size + look_back).
test_plot = np.full_like(data, np.nan)
test_plot[len(train_predict) + 2 * look_back:] = test_predict
plt.plot(test_plot, label='Test Prediction', color='limegreen', linestyle='--', linewidth=2)

plt.title('LSTM', fontsize=16)
plt.xlabel('Time Steps', fontsize=14)
plt.legend(loc='upper left', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
for figure_path in (r'...\LSTM_figures\LSTM_Model2.png', r'...\LSTM_figures\LSTM_Model2.pdf'):
    plt.savefig(figure_path)
plt.show()




# Per-epoch training loss vs. held-out MSE, as recorded by the callback.
plt.figure(figsize=(10, 6))
plt.plot(callback.losses, label='Train', color='dodgerblue', linewidth=2)
plt.plot(callback.test_losses, label='Test', color='coral', linewidth=2)
plt.title('LSTM', fontsize=16)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('MSE', fontsize=14)
plt.yscale('log')  # log scale helps when the loss spans a large range of values
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(loc='upper right', fontsize=12)
plt.savefig(r'...\LSTM_figures\LSTM_Model2_loss.png')
plt.savefig(r'...\LSTM_figures\LSTM_Model2_loss.pdf')
plt.show()
# Raw string: in a normal literal '\L' is an invalid escape sequence
# (SyntaxWarning on recent Pythons). The path value is unchanged.
np.save(r'...\Loss\LSTM_loss_model2.npy', np.array(callback.losses))
