# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import argparse
import copy
import os.path
import torch
import utils
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from Dataset_BasicMotions import Dataset_BasicMotions
from Dataset_ERing import Dataset_ERing
from Dataset_Heartbeat import Dataset_Heartbeat
from Dataset_JapaneseVowels import Dataset_JapaneseVowels
from Dataset_Libras import Dataset_Libras
from Dataset_NATOPS import Dataset_NATOPS
from Dataset_Life_Expectancy import Dataset_Life_Expectancy
from sklearn.model_selection import train_test_split
from Dataset_PEMS_SF import Dataset_PEMS_SF
from Dataset_RacketSports import Dataset_RacketSports


class LSTM_Model(nn.Module):
    """Single-layer LSTM followed by a small fully connected head.

    The network is ``LSTM(hidden=30) -> Linear(30, 60) -> ReLU ->
    Linear(60, output_dimension) -> Sigmoid``. Only the LSTM output at the
    last time step is fed to the head.

    Besides ``forward``, the class offers ``predict_proba`` / ``predict``
    helpers that mimic the scikit-learn classifier API and return NumPy
    arrays. Those helpers support binary classification only
    (``output_dimension == 1``).
    """

    def __init__(self, input_dimension, output_dimension):
        """
        Args:
            input_dimension: number of features per time step (``input_size`` of the LSTM).
            output_dimension: number of sigmoid outputs produced by the head.
        """
        super().__init__()
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.prediction_model_type = utils.Prediction_Model_Types.LSTM
        self.lstm = nn.LSTM(input_size=input_dimension, hidden_size=30, num_layers=1, batch_first=True)
        self.fc1 = nn.Linear(30, 60)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(60, output_dimension)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid outputs for a batch of sequences.

        Args:
            x: batched input; the LSTM is built with ``batch_first=True``, so the
                expected layout is (batch, time, input_dimension).

        Returns:
            Tensor of shape (batch, output_dimension) with values in [0, 1].
        """
        x, _ = self.lstm(x)

        # keep only the LSTM output of the last time step
        x = x[:, -1, :]

        x = self.act1(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        return x

    def predict_proba(self, X):
        """
        simulate predict_proba() from scikit-learn

        Args:
            X: batched input tensor accepted by ``forward``.

        Returns:
            NumPy array of shape (batch, 2): column 0 is ``1 - p`` and column 1
            is ``p``, where ``p`` is the single sigmoid output of the network.

        Raises:
            ValueError: if the model was built with ``output_dimension != 1``.
        """
        # Validate before doing any work so that unsupported configurations
        # fail fast instead of paying for a forward pass first.
        if self.output_dimension != 1:
            raise ValueError("Current Implementation only support binary classification.")

        # Pure inference: no autograd graph is needed for a NumPy result.
        with torch.no_grad():
            prob_1 = self(X)

        result_vec = torch.cat((1 - prob_1, prob_1), dim=-1)

        # .cpu() is a no-op for CPU tensors and makes the conversion work when
        # the model/input live on another device.
        return result_vec.detach().cpu().numpy()

    def predict(self, X):
        """
        simulate predict() from scikit-learn

        Args:
            X: batched input tensor accepted by ``forward``.

        Returns:
            NumPy integer array of shape (batch,) holding the index (0 or 1) of
            the larger column of ``predict_proba(X)``.
        """
        probabilities = self.predict_proba(X)

        pred_class = np.argmax(probabilities, axis=-1)

        return pred_class


def _save_accuracy_plot(train_acc_epochs, val_acc_epochs, output_path):
    """Write the per-epoch training/validation accuracy curves to ``output_path``."""
    plt.plot(train_acc_epochs, label="Training")
    plt.plot(val_acc_epochs, label="Validation")
    plt.legend()
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.savefig(output_path)
    plt.close('all')


def model_train(model, X_train, y_train, X_val, y_val, save_model_path):
    """Train ``model`` with mini-batch Adam + BCE loss and keep the best weights.

    After every epoch the model is evaluated on ``(X_val, y_val)``; the weights
    with the highest validation accuracy are restored into ``model`` at the end
    and saved to ``save_model_path``. The accuracy curves are written to
    ``accuracy.png`` inside ``result_folder`` every ``training_plot_interval``
    epochs and once more after training. The best validation accuracy is
    printed and appended to ``result_summary_file``.

    This function reads the following module-level names, which must be defined
    before it is called: ``learning_rate``, ``weight_decay``, ``batch_size``,
    ``num_epochs``, ``training_plot_interval``, ``result_folder`` and
    ``result_summary_file``.

    Args:
        model: the ``nn.Module`` to train (modified in place).
        X_train: training inputs, sliced along the first axis into batches.
        y_train: training targets aligned with ``X_train``.
            NOTE(review): presumably float tensors with the same shape as the
            model output, as ``nn.BCELoss`` requires -- confirm with the loaders.
        X_val: validation inputs, evaluated in a single forward pass.
        y_val: validation targets aligned with ``X_val``.
        save_model_path: file path the final ``state_dict`` is saved to.

    Returns:
        None.
    """
    print("Start training...")

    # loss function and optimizer
    loss_fn = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    num_train_samples = len(X_train)
    accuracy_plot_path = os.path.join(result_folder, 'accuracy.png')

    # Hold the best model
    best_val_acc = - np.inf  # init to negative infinity
    best_weights = None

    # Hold the training results
    train_acc_epochs = []
    val_acc_epochs = []

    for epoch in range(num_epochs):  # for each epoch
        model.train()

        train_acc_current_epoch = 0
        for batch_start_index in range(0, num_train_samples, batch_size):  # for each batch
            optimizer.zero_grad()

            # take a batch (the last one may be shorter than batch_size)
            batch_end_index = batch_start_index + batch_size
            X_batch = X_train[batch_start_index:batch_end_index]
            y_batch = y_train[batch_start_index:batch_end_index]

            # forward pass
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)

            # backward pass
            loss.backward()

            # update weights
            optimizer.step()

            # training accuracy of the current batch
            train_batch_acc = (y_pred.round() == y_batch).float().mean()

            # weight by the batch's share of the training set so that the sum
            # over all batches is the accuracy over the whole epoch
            train_acc_current_epoch += float(train_batch_acc * (len(X_batch) / num_train_samples))

        # training accuracy of the current epoch
        train_acc_epochs.append(train_acc_current_epoch)

        # evaluate validation accuracy at end of each epoch; no gradients are
        # needed here, so avoid building an autograd graph over the whole set
        model.eval()
        with torch.no_grad():
            y_pred = model(X_val)

        val_acc = float((y_pred.round() == y_val).float().mean())
        val_acc_epochs.append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_weights = copy.deepcopy(model.state_dict())

        if epoch % training_plot_interval == 0:
            # Save intermediate training results
            _save_accuracy_plot(train_acc_epochs, val_acc_epochs, accuracy_plot_path)

    print("\nTraining is done.\n")

    # restore the best model; best_weights is still None when no epoch ever
    # improved on -inf (e.g. num_epochs == 0 or NaN accuracies), in which case
    # the current weights are kept instead of crashing in load_state_dict
    if best_weights is not None:
        model.load_state_dict(best_weights)

    # save the best model
    torch.save(model.state_dict(), save_model_path)

    print("Best validation accuracy: {}".format(best_val_acc))
    utils.append_to_file(result_summary_file, "\nBest validation accuracy: \n{}\n".format(best_val_acc))

    # Save final training results
    _save_accuracy_plot(train_acc_epochs, val_acc_epochs, accuracy_plot_path)


if __name__ == "__main__":

    print("\n\n\nTraining LSTM starting...")

    # ------------------------- command-line options -------------------------
    supported_datasets = [
        utils.Dataset_Names.life_expectancy, utils.Dataset_Names.natops,
        utils.Dataset_Names.heartbeat, utils.Dataset_Names.racket_sports,
        utils.Dataset_Names.basic_motions, utils.Dataset_Names.ering,
        utils.Dataset_Names.japanese_vowels, utils.Dataset_Names.libras,
        utils.Dataset_Names.PEMS_SF,
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument('--result_folder', type=str, default="temp_results_LSTM")
    parser.add_argument('--result_folder_suffix', type=str, default="")
    parser.add_argument('--dataset_name', type=str, default=utils.Dataset_Names.life_expectancy,
                        choices=supported_datasets)
    parser.add_argument('--random_seed', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--num_epochs', type=int, default=3000)
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--split_percentage', type=float, default=0.7)
    args = parser.parse_args()

    # Module-level settings; model_train() reads several of these by name
    # (learning_rate, weight_decay, num_epochs, batch_size, result_folder).
    result_folder = args.result_folder + args.result_folder_suffix
    random_seed = args.random_seed
    dataset_name = args.dataset_name
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    split_percentage = args.split_percentage

    # Echo the effective configuration
    for label, value in (
        ("result_folder: ", result_folder),
        ("random_seed: ", random_seed),
        ("dataset_name: ", dataset_name),
        ("learning_rate: ", learning_rate),
        ("weight_decay: ", weight_decay),
        ("num_epochs: ", num_epochs),
        ("batch_size: ", batch_size),
        ("split_percentage: ", split_percentage),
    ):
        print(label, value)

    # Output locations, all inside the result folder
    save_model_path = os.path.join(result_folder, "saved_LSTM_model")
    result_summary_file = os.path.join(result_folder, "result_summary.txt")
    config_file = os.path.join(result_folder, "configurations.txt")

    os.makedirs(result_folder, exist_ok=True)

    # model_train() refreshes the accuracy plot every this many epochs
    training_plot_interval = 100

    utils.log_config(config_file, args, result_summary_file)

    utils.set_random_seed(random_seed)

    ########################### Loading dataset ###########################
    data_path_root = os.getcwd()

    # Datasets that are all loaded the same way: load_dataset("train")
    train_split_datasets = (
        (utils.Dataset_Names.natops, Dataset_NATOPS),
        (utils.Dataset_Names.heartbeat, Dataset_Heartbeat),
        (utils.Dataset_Names.racket_sports, Dataset_RacketSports),
        (utils.Dataset_Names.basic_motions, Dataset_BasicMotions),
        (utils.Dataset_Names.ering, Dataset_ERing),
        (utils.Dataset_Names.japanese_vowels, Dataset_JapaneseVowels),
        (utils.Dataset_Names.libras, Dataset_Libras),
        (utils.Dataset_Names.PEMS_SF, Dataset_PEMS_SF),
    )

    if dataset_name == utils.Dataset_Names.life_expectancy:
        # Life Expectancy has its own preprocessing entry point
        # label: Life Expectancy in 2015
        dataset = Dataset_Life_Expectancy(data_path_root, None)
        X, y, _, _ = dataset.data_process_LE_one_hot()

        assert dataset.cate_index_start == dataset.num_of_features - len(dataset.categorical_features)
    else:
        for candidate_name, dataset_class in train_split_datasets:
            if dataset_name == candidate_name:
                dataset = dataset_class(data_path_root, None)
                X, y, _, _ = dataset.load_dataset("train")
                break
        else:
            # no known dataset matched
            raise ValueError("what??? dataset_name={}".format(dataset_name))

    ###############################################################################

    print(X.shape)
    print(y.shape)

    # Build the LSTM model and report its size
    model = LSTM_Model(dataset.input_dimension, dataset.output_dimension)
    print("total number of parameters: ", sum(p.numel() for p in model.parameters()))

    # Shuffle and split the loaded data; the held-out part is passed to
    # model_train() below as its validation set.
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=split_percentage, shuffle=True)

    for label, value in (
        ("X_train.shape: ", X_train.shape),
        ("y_train.shape: ", y_train.shape),
        ("X_test.shape: ", X_test.shape),
        ("y_test.shape: ", y_test.shape),
    ):
        print(label, value)

    # Start training
    model_train(model, X_train, y_train, X_test, y_test, save_model_path)

    # Drop the trained instance and rebuild the model from the saved weights
    model = None
    model = LSTM_Model(dataset.input_dimension, dataset.output_dimension)
    saved_state = torch.load(save_model_path)
    model.load_state_dict(saved_state)
    model.eval()
