# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import argparse
import os.path
import torch
import utils
from Dataset_PEMS_SF import Dataset_PEMS_SF


class Rule_Model:
    """Hand-written, rule-based classifier exposing a scikit-learn-like ``predict()``.

    Each supported dataset has a fixed rule evaluated on the trailing timesteps
    of a few named features. ``version`` (a string, ``"1"`` or ``"2"``) selects
    the rule variant. For most datasets ``"1"`` requires every listed feature to
    satisfy the condition (logical AND) and ``"2"`` requires at least one
    (logical OR). ``predict_LE`` and ``predict_PEMS_SF`` document their own
    variants. Predictions are returned as a 1-D tensor of 0s and 1s, one per
    sample.
    """

    def __init__(self, dataset, version, dataset_name):
        """
        :param dataset: object exposing ``feature_names`` (searched with
            ``.index(name)``), ``length_of_sequence`` and ``num_of_features``.
        :param version: rule variant; compared against the strings ``"1"`` and ``"2"``.
        :param dataset_name: one of the ``utils.Dataset_Names`` values handled by
            :meth:`predict`.
        """
        super().__init__()

        self.prediction_model_type = utils.Prediction_Model_Types.Rule

        self.dataset = dataset

        self.version = version

        self.dataset_name = dataset_name

    def predict(self, X):
        """
        simulate predict() from scikit-learn

        :param X: 3-D tensor of shape (batch, length_of_sequence, num_of_features).
        :return: 1-D float tensor of length ``batch`` holding 0/1 predictions.
        :raises AssertionError: if ``X`` is not a 3-D tensor matching the
            dataset's sequence length and feature count.
        :raises ValueError: if ``dataset_name`` or ``version`` is not supported.
        """
        assert torch.is_tensor(X)
        # Check the rank before reading X.shape[2] so malformed input fails on
        # an assertion rather than an IndexError.
        assert X.dim() == 3
        assert X.shape[1] == self.dataset.length_of_sequence
        assert X.shape[2] == self.dataset.num_of_features

        # Ordered (name, handler) pairs compared with ``==`` in order, i.e. the
        # same first-match semantics as an if/elif chain.
        handlers = (
            (utils.Dataset_Names.life_expectancy, self.predict_LE),
            (utils.Dataset_Names.natops, self.predict_NATOPS),
            (utils.Dataset_Names.heartbeat, self.predict_Heartbeat),
            (utils.Dataset_Names.racket_sports, self.predict_RacketSports),
            (utils.Dataset_Names.basic_motions, self.predict_BasicMotions),
            (utils.Dataset_Names.ering, self.predict_ERing),
            (utils.Dataset_Names.japanese_vowels, self.predict_JapaneseVowels),
            (utils.Dataset_Names.libras, self.predict_Libras),
            (utils.Dataset_Names.PEMS_SF, self.predict_PEMS_SF),
        )
        for name, handler in handlers:
            if self.dataset_name == name:
                return handler(X)

        raise ValueError("what??? dataset_name={}".format(self.dataset_name))

    @staticmethod
    def _mask_to_predictions(X, mask):
        """Convert a per-sample boolean ``mask`` of length ``X.shape[0]`` into 0/1 predictions.

        The result is allocated with ``torch.zeros(X.shape[0])`` (default dtype
        and default device), independent of the device ``X`` lives on.
        """
        predictions = torch.zeros(X.shape[0])
        predictions[mask.to(predictions.device)] = 1
        return predictions

    def _all_positive_tail_rule(self, X, feature_names, num_timesteps):
        """Shared rule: a feature "holds" when it is > 0 on each of the last ``num_timesteps`` steps.

        ``version == "1"`` predicts 1 when every feature in ``feature_names``
        holds; ``version == "2"`` predicts 1 when at least one does. A sequence
        shorter than ``num_timesteps`` is used in full (Python slice semantics).

        :raises ValueError: if a name is missing from ``self.dataset.feature_names``
            (raised by ``list.index``) or the version is not implemented.
        """
        feature_indices = [self.dataset.feature_names.index(name) for name in feature_names]

        if self.version == "1":
            combine_features = torch.all
        elif self.version == "2":
            combine_features = torch.any
        else:
            raise ValueError("Version {} is not implemented.".format(self.version))

        # (batch, <= num_timesteps, n_features)
        tail = X[:, -num_timesteps:, feature_indices]
        # (batch, n_features): True where the feature is positive on every trailing step.
        feature_holds = torch.all(tail > 0, dim=1)
        return self._mask_to_predictions(X, combine_features(feature_holds, dim=1))

    def predict_LE(self, X):
        """
        simulate predict() from scikit-learn

        Over the last 5 timesteps, predict 1 when 'Least_Developed' == 0,
        'People_practicing_open_defecation' < 0 and
        'People_using_at_least_basic_drinking_water_services' > 0 on every step,
        and additionally
        version "1": 'GDP_per_capita' > 0 AND 'Health_expenditure' > 0 on every step;
        version "2": 'GDP_per_capita' > 0 OR 'Health_expenditure' > 0 on every step.
        """
        feature_names = self.dataset.feature_names
        least_developed_index = feature_names.index('Least_Developed')
        gdp_index = feature_names.index('GDP_per_capita')
        health_index = feature_names.index('Health_expenditure')
        defecation_index = feature_names.index('People_practicing_open_defecation')
        water_index = feature_names.index('People_using_at_least_basic_drinking_water_services')

        if self.version not in ("1", "2"):
            raise ValueError("Version {} is not implemented.".format(self.version))

        consecutive_years = 5
        tail = X[:, -consecutive_years:, :]

        def holds_on_all_steps(feature_index, condition):
            # (batch,): condition is true on every one of the trailing steps.
            return torch.all(condition(tail[:, :, feature_index]), dim=1)

        not_least_developed = holds_on_all_steps(least_developed_index, lambda v: v == 0)
        gdp_positive = holds_on_all_steps(gdp_index, lambda v: v > 0)
        health_positive = holds_on_all_steps(health_index, lambda v: v > 0)
        defecation_negative = holds_on_all_steps(defecation_index, lambda v: v < 0)
        water_positive = holds_on_all_steps(water_index, lambda v: v > 0)

        if self.version == "1":
            spending_ok = gdp_positive & health_positive
        else:
            spending_ok = gdp_positive | health_positive

        mask = not_least_developed & spending_ok & defecation_negative & water_positive
        return self._mask_to_predictions(X, mask)

    def predict_NATOPS(self, X):
        """
        simulate predict() from scikit-learn

        Rule on the last 10 timesteps: 'Hand tip left, X coordinate' and/or
        'Hand tip right, X coordinate' > 0 on every step (v1: both, v2: either).
        """
        return self._all_positive_tail_rule(
            X, ['Hand tip left, X coordinate', 'Hand tip right, X coordinate'], num_timesteps=10)

    def predict_Heartbeat(self, X):
        """
        simulate predict() from scikit-learn

        Rule on the last 5 timesteps: feature_1, feature_2, feature_3 > 0 on every
        step (v1: all three, v2: any).
        """
        return self._all_positive_tail_rule(
            X, ['feature_1', 'feature_2', 'feature_3'], num_timesteps=5)

    def predict_RacketSports(self, X):
        """
        simulate predict() from scikit-learn

        Rule on the last 5 timesteps: feature_1, feature_5 > 0 on every step
        (v1: both, v2: either).
        """
        return self._all_positive_tail_rule(
            X, ['feature_1', 'feature_5'], num_timesteps=5)

    def predict_BasicMotions(self, X):
        """
        simulate predict() from scikit-learn

        Rule on the last 10 timesteps: feature_1, feature_3, feature_6 > 0 on
        every step (v1: all three, v2: any).
        """
        return self._all_positive_tail_rule(
            X, ['feature_1', 'feature_3', 'feature_6'], num_timesteps=10)

    def predict_ERing(self, X):
        """
        simulate predict() from scikit-learn

        Rule on the last 10 timesteps: feature_2, feature_3 > 0 on every step
        (v1: both, v2: either).
        """
        return self._all_positive_tail_rule(
            X, ['feature_2', 'feature_3'], num_timesteps=10)

    def predict_JapaneseVowels(self, X):
        """
        simulate predict() from scikit-learn

        Rule on the last 20 timesteps: feature_1, feature_6, feature_12 > 0 on
        every step (v1: all three, v2: any).
        """
        return self._all_positive_tail_rule(
            X, ['feature_1', 'feature_6', 'feature_12'], num_timesteps=20)

    def predict_Libras(self, X):
        """
        simulate predict() from scikit-learn

        Rule on the last 20 timesteps: feature_1, feature_2 > 0 on every step
        (v1: both, v2: either).
        """
        return self._all_positive_tail_rule(
            X, ['feature_1', 'feature_2'], num_timesteps=20)

    def predict_PEMS_SF(self, X):
        """
        simulate predict() from scikit-learn

        Predict 1 when the values of feature_1, feature_100 and feature_300,
        summed over a trailing window, exceed 10. The window is the last 50
        timesteps for version "1" and the last 100 for version "2".
        """
        feature_indices = [self.dataset.feature_names.index(name)
                           for name in ('feature_1', 'feature_100', 'feature_300')]

        if self.version == "1":
            consecutive_timesteps = 50
        elif self.version == "2":
            consecutive_timesteps = 100
        else:
            raise ValueError("Version {} is not implemented.".format(self.version))

        tail = X[:, -consecutive_timesteps:, :]
        # Sum each feature over the window, then add the per-feature sums in
        # the fixed order feature_1, feature_100, feature_300.
        total = sum(tail[:, :, feature_index].sum(dim=1) for feature_index in feature_indices)
        return self._mask_to_predictions(X, total > 10)


def main():
    """Evaluate the rule-based model on the PEMS-SF "test" split.

    Parses the command line, prints each option and logs the configuration via
    ``utils.log_config``. It then loads the "test" split with
    ``Dataset_PEMS_SF`` rooted at the current working directory. Finally it
    prints the predictions and the fraction of samples predicted as 1.

    Most options (``--max_episodes``, ``--lambda_*``, ``--learning_rate``, ...)
    are only printed and logged; ``Rule_Model`` does not consume them.
    ``--rule_version`` selects the rule variant (default ``"2"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--result_folder', required=False, type=str, default="temp_results_RL")
    parser.add_argument('--result_folder_suffix', required=False, type=str, default="")
    parser.add_argument('--random_seed', required=False, type=int, default=1)
    parser.add_argument('--max_episodes', required=False, type=int, default=1000)
    parser.add_argument('--lambda_proximity', required=False, type=float, default=0.0)
    parser.add_argument('--lambda_entropy', required=False, type=float, default=0.0)
    parser.add_argument('--max_number_of_interventions', required=False, type=int, default=100)
    parser.add_argument('--intervention_type', required=False, type=str, default="drifting",
                        choices=["drifting", "point"])
    parser.add_argument('--feature_extreme_values', required=False, type=int, default=1)
    parser.add_argument('--learning_rate', required=False, type=float, default=0.001)
    parser.add_argument('--weight_decay', required=False, type=float, default=0.001)
    parser.add_argument('--rule_version', required=False, type=str, default="2", choices=["1", "2"],
                        help="Rule_Model variant to evaluate.")
    args = parser.parse_args()

    for option in ("random_seed", "max_episodes", "lambda_proximity", "lambda_entropy",
                   "max_number_of_interventions", "intervention_type",
                   "feature_extreme_values", "learning_rate", "weight_decay", "rule_version"):
        print(option + ": ", getattr(args, option))

    result_folder = args.result_folder + args.result_folder_suffix
    print("result_folder: ", result_folder)

    result_summary_file = os.path.join(result_folder, "result_summary.txt")
    config_file = os.path.join(result_folder, "configurations.txt")

    os.makedirs(result_folder, exist_ok=True)

    utils.log_config(config_file, args, result_summary_file)

    utils.set_random_seed(args.random_seed)

    dataset = Dataset_PEMS_SF(os.getcwd(), None)
    X_test, _y_test, _X_max, _X_min = dataset.load_dataset("test")

    model = Rule_Model(dataset, version=args.rule_version, dataset_name=utils.Dataset_Names.PEMS_SF)

    # predict() only reads its input, so no defensive copy is needed.
    y_pred = model.predict(X_test)

    print("y_pred: ", y_pred)

    print("Percentage of prediction of 1s: ", (y_pred.round().sum() / len(y_pred)).item())


if __name__ == "__main__":
    main()
