# -*- coding: utf-8 -*-
"""
Created on Tue Sep 10 20:59:18 2024

@author: kernel
"""
import os
import time
import argparse
# Switch the process working directory before anything else runs. The CSV
# loaded further down is opened through a relative path, so it is resolved
# against this directory.
# NOTE(review): the leading "..." looks like a redacted/placeholder prefix --
# replace it with the real location of the checkout before running.
# NOTE(review): presumably this sits ahead of the neuralforecast imports so
# that a local "neuralforecast-main" checkout is importable from the working
# directory -- confirm; that only works when the interpreter has the current
# directory on sys.path.
new_path = r"...\neuralforecast-main"
os.chdir(new_path)

from neuralforecast.utils import AirPassengersDF
from neuralforecast import NeuralForecast
from neuralforecast.models import LSTM, NHITS, RNN, KAN
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# Render every figure at 600 dpi.
plt.rcParams['figure.dpi'] = 600

# from sklearn.metrics import mean_squared_error, mean_absolute_error

# Build the path with os.path.join rather than a hard-coded backslash: "\m" is
# an invalid escape sequence in a normal string literal (it raises a
# DeprecationWarning/SyntaxWarning) and a backslash is only a path separator
# on Windows.
df = pd.read_csv(os.path.join("data_generator", "model_13_0.csv"))

# Target values as a float32 array; the plotting code below uses it as the
# reference curve and as the template for the NaN-padded overlay arrays.
data = df['y'].values.astype('float32')
# df['ds'] = pd.to_datetime(df['ds'])

# Use a 1-based integer step counter as the time column (NeuralForecast is
# created with freq=1 further down).
df['ds'] = range(1, len(df) + 1)

horizon = 1        # steps forecast per prediction (passed as h=...)
input_horizon = 5  # length of the input window (passed as input_size=...)

Y = df                              # full series (same object as df)
Y_df = Y.head(round(len(Y) * 0.8))  # first 80% of the rows: training split






# Models to train. Try different hyperparameters to improve accuracy.
#
# Disabled alternatives, kept for reference:
#   LSTM(h=horizon, max_steps=500, scaler_type='standard',
#        encoder_hidden_size=64)
#   NHITS(h=horizon, input_size=2 * horizon, max_steps=100,
#         n_freq_downsample=[2, 1, 1])
kan_settings = dict(
    h=horizon,                 # forecast horizon
    input_size=input_horizon,  # length of the input window
    hidden_size=64,
    max_steps=500,             # number of training steps
    scaler_type='robust',
    random_seed=5,
)
models = [KAN(**kan_settings)]
# `ds` holds integer steps, so the frequency is given as the integer 1.
nf = NeuralForecast(models=models, freq=1)

# Time only the training call. perf_counter() is a monotonic, high-resolution
# clock meant for measuring intervals; time.time() follows the wall clock and
# can jump if the system time is adjusted while training runs.
start_time = time.perf_counter()
nf.fit(df=Y_df)
elapsed_time = time.perf_counter() - start_time

# With no `df` argument the forecast is produced from the data passed to fit().
Y_hat_df = nf.predict()
# Flatten the index into regular columns before the merge on 'ds' below.
# NOTE(review): whether predict() returns an id index or a plain one depends
# on the installed neuralforecast version -- confirm the resulting columns.
Y_hat_df = Y_hat_df.reset_index()


# Overlay the forecast returned by nf.predict() on the full observed series.
# The outer join keeps every observed row and also any forecast-only steps.
merged_df = Y.merge(Y_hat_df[['ds', 'KAN']], on='ds', how='outer')
observed_vs_forecast = merged_df.set_index('ds')[['y', 'KAN']]

fig, ax = plt.subplots(figsize=(20, 7))
observed_vs_forecast.plot(ax=ax, linewidth=2)

# NOTE(review): the title and y-label still mention AirPassengers although the
# series is read from model_13_0.csv -- confirm the intended wording.
ax.set_title('AirPassengers Forecast', fontsize=22)
ax.set_xlabel('Timestamp [t]', fontsize=20)
ax.set_ylabel('Monthly Passengers', fontsize=20)
ax.grid()
ax.legend(prop={'size': 15})
plt.show()



# Rolling forecasts over the whole series with the already-fitted model.
start_index = 0
end_index = len(Y) - input_horizon

# Gather the per-window forecasts in a list and concatenate once at the end.
# Calling pd.concat inside the loop re-copies everything accumulated so far on
# each iteration (quadratic work) and starts from an empty DataFrame, which
# recent pandas versions warn about.
window_forecasts = []
for i in range(start_index, end_index, horizon):
    # Expanding window: every row before position i + input_horizon.
    window_data = Y.iloc[:i + input_horizon]
    Y_hat_df_pre = nf.predict(df=window_data)
    window_forecasts.append(Y_hat_df_pre)

# Keep the original outcome (an empty frame) when the loop never runs.
predictions = pd.concat(window_forecasts) if window_forecasts else pd.DataFrame()

# Forecast values as a 2-D array with one column ('KAN').
y_predict = predictions[['KAN']].values

# Build two NaN-padded copies of the series so that the rolling forecasts can
# be drawn at the positions they refer to; NaN entries are simply not drawn.
n_train = len(Y_df)

# Forecasts for the training part start after the first input window.
train_plot = np.full_like(data, np.nan)
train_plot[input_horizon:n_train] = y_predict[:n_train - input_horizon].ravel()

# Last 35 forecasts, placed at the tail of the series.
# NOTE(review): 35 is hard-coded; it equals the hold-out size only when
# len(Y) - len(Y_df) == 35, and the slice lengths only match for horizon == 1
# -- confirm against the data set in use.
test_plot = np.full_like(data, np.nan)
test_plot[-horizon * 35:len(data)] = y_predict[-35:].ravel()

plt.figure(figsize=(10, 6))
dashed = dict(linestyle='--', linewidth=2)
plt.plot(data, label='Original Data', color='dodgerblue', linewidth=2)
plt.plot(train_plot, label='Train Prediction', color='coral', **dashed)
plt.plot(test_plot, label='Test Prediction', color='limegreen', **dashed)
plt.title('[5,64,1]KAN', fontsize=16)
plt.xlabel('Time Steps', fontsize=14)
# plt.ylabel('Model 1', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(loc='upper left', fontsize=12)
plt.show()

# Same overlay as the single-forecast figure, but using every rolling
# forecast collected in `predictions`, joined to the observations on 'ds'.
fig, ax = plt.subplots(figsize=(20, 7))
merged_df_pre = Y.merge(predictions[['ds', 'KAN']], on='ds', how='outer')
rolling_vs_observed = merged_df_pre.set_index('ds')[['y', 'KAN']]
rolling_vs_observed.plot(ax=ax, linewidth=2)

# NOTE(review): the title and y-label still mention AirPassengers although the
# series is read from model_13_0.csv -- confirm the intended wording.
ax.set_title('AirPassengers Forecast', fontsize=22)
ax.set_xlabel('Timestamp [t]', fontsize=20)
ax.set_ylabel('Monthly Passengers', fontsize=20)
ax.grid()
ax.legend(prop={'size': 15})
plt.show()



# ---- Error on the training split -------------------------------------------
# Observations after the first input window, paired with the rolling forecasts
# made for the same positions. Despite the "mse_" prefix, only the MAE is
# computed for the training split.
mse_y_train = Y_df['y'].iloc[input_horizon:len(Y_df)]
mse_kan_train = y_predict[:len(Y_df) - input_horizon].flatten()

train_error = mse_y_train - mse_kan_train
mae_train = np.mean(np.abs(train_error))
print(f"Mean Absolute Error (MAE) on Training Data: {mae_train:.4e}")

# ---- Error on the last 35 steps --------------------------------------------
# NOTE(review): 35 is hard-coded; it matches the hold-out size only when
# len(Y) - len(Y_df) == 35 -- confirm against the data set in use.
mse_y_val = Y['y'].iloc[-35:]
mse_kan_val = y_predict[-35:].flatten()

val_error = mse_y_val - mse_kan_val
mse_val = np.mean(val_error ** 2)
rmse_val = np.sqrt(mse_val)
mae_val = np.mean(np.abs(val_error))

print(f"Mean Squared Error (MSE) on Validation Data: {mse_val:.4e}")
print(f"Root Mean Squared Error (RMSE) on Validation Data: {rmse_val:.4e}")
print(f"Mean Absolute Error (MAE) on Validation Data: {mae_val:.4e}")

# Duration of the nf.fit() call measured earlier in the script.
print(f"Training time: {elapsed_time:.4f} seconds")





