# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import time
import datetime
import random
import itertools
import matplotlib.colors as mcolors
import numpy as np
import torch
import os
import torch.nn.functional as F
from sklearn.neighbors import LocalOutlierFactor


class Prediction_Model_Types:
    """
    String identifiers for the kinds of prediction model handled in this module.

    `make_model_prediction()` compares `prediction_model.prediction_model_type` against these values
    to decide how the model is invoked.
    """
    LSTM = "LSTM"  # invoked by calling the model directly on the input tensor
    Rule = "Rule"  # invoked through `.predict(...)` on the non-one-hot input
    RandomForest = "RandomForest"  # invoked through `.predict_proba(...)` on the flattened non-one-hot input
    KNN = "KNN"  # handled the same way as RandomForest


class Plausibility:
    """
    Labels used for plausibility results.

    `compute_plausibility()` returns the label produced by `lof.predict(...)`, which its inline comment
    describes as -1 for anomalies/outliers and +1 for inliers.
    """
    plausible = 1  # inlier
    implausible = -1  # anomaly / outlier


class Dataset_Names:
    """
    String identifiers of the datasets recognised by `get_default_prediction_model_version()`.
    """
    life_expectancy = "LifeExpectancy"
    natops = "NATOPS"
    heartbeat = "Heartbeat"
    racket_sports = "RacketSports"
    basic_motions = "BasicMotions"
    ering = "ERing"
    japanese_vowels = "JapaneseVowels"
    libras = "Libras"
    PEMS_SF = "PEMS_SF"


def set_random_seed(seed):
    """
    Seed every random number generator used by this project with the same value.

    The seeders are applied in order: Python's `random`, NumPy, PyTorch, then the two PyTorch CUDA seeders.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def append_to_file(target_file, content_to_append):
    """
    Append `content_to_append` to the end of `target_file`, creating the file if it does not exist.

    :param target_file: path of the file to append to
    :param content_to_append: text to write
    """
    # The context manager closes the handle even when `write` raises.
    with open(target_file, "a") as f:
        f.write(content_to_append)


def overwrite_file(target_file, content_to_write):
    """
    Replace the contents of `target_file` with `content_to_write`, creating the file if it does not exist.

    :param target_file: path of the file to (re)write
    :param content_to_write: text to write
    """
    # The context manager closes the handle even when `write` raises.
    with open(target_file, "w") as f:
        f.write(content_to_write)


def create_color_maker_combo():
    """
    Build the list of (marker, colour) pairs used to tell plotted features apart.

    Colours alone are not enough when there are many features, so every marker is paired with every
    Tableau colour name. Pairs are ordered marker-major (all colours for the first marker, then the next).
    """
    marker_styles = ['o', 'x', 'v', "<", ">", "1", "2"]
    return [
        (marker, color_name)
        for marker in marker_styles
        for color_name in mcolors.TABLEAU_COLORS
    ]


def log_config(config_file, args, result_summary_file):
    """
    Append a run header (separator, current time, and `args`) to both the config file and
    the result summary file.
    """
    destinations = (config_file, result_summary_file)

    def write_to_both(text):
        # config file first, then the summary file, for every entry
        for destination in destinations:
            append_to_file(destination, text)

    write_to_both("\n#########################\n")

    # record the current time
    timestamp = datetime.datetime.now()
    write_to_both("\ncurrent_time: \n{}\n".format(timestamp))

    # record the configurations
    write_to_both("\nargs: \n{}\n".format(args))


def convert_to_one_hot(X_tensor, dataset, X_cate_is_standardized):
    """
    Expand the categorical features of a batch of samples into one-hot columns.

    Continuous features (columns before `dataset.cate_index_start`) are kept as they are; each categorical
    feature is replaced by `dataset.cate_num_classes_dict[feature_name]` one-hot columns appended in order.

    Use this function before calling neural networks.

    :param X_tensor: tensor whose dims 1 and 2 are (length_of_sequence, num_of_features)
    :param dataset: object providing the feature metadata read below
    :param X_cate_is_standardized: if True, categorical columns are first mapped back with
        `value * X_stds + X_means` and rounded
    :return: (one-hot tensor, dict mapping each categorical feature name to its one-hot column indexes)
    """
    assert torch.is_tensor(X_tensor)
    assert X_tensor.shape[1] == dataset.length_of_sequence
    assert X_tensor.shape[2] == dataset.num_of_features

    one_hot_columns = {}  # feature name -> column indexes of its one-hot block

    # nothing to expand when both encodings have the same width
    if dataset.input_dimension == dataset.num_of_features:
        return X_tensor, one_hot_columns

    # start from the continuous features and grow the tensor one categorical feature at a time
    encoded = X_tensor[:, :, :dataset.cate_index_start]

    for column in range(dataset.cate_index_start, len(dataset.feature_names)):
        name = dataset.feature_names[column]
        n_classes = dataset.cate_num_classes_dict[name]

        class_values = X_tensor[:, :, column]
        if X_cate_is_standardized:
            # one-hot conversion needs integer class ids, so undo the standardization first
            class_values = (class_values * dataset.X_stds[column] + dataset.X_means[column]).round()

        one_hot_block = F.one_hot(class_values.to(torch.int64), num_classes=n_classes)

        first_slot = encoded.shape[-1]
        one_hot_columns[name] = list(range(first_slot, first_slot + n_classes))

        encoded = torch.cat((encoded, one_hot_block), dim=-1)

    expected_shape = (X_tensor.shape[0], dataset.length_of_sequence, dataset.input_dimension)
    if encoded.shape != expected_shape:
        raise ValueError("What??? X_final.shape: ", encoded.shape)

    return encoded, one_hot_columns


def undo_one_hot(X, dataset):
    """
    Convert one one-hot encoded sample back to one column per original feature.

    Continuous columns are copied; each categorical feature becomes the argmax over its one-hot columns.

    Use this function before plotting.

    :param X: tensor of shape (length_of_sequence, input_dimension)
    :param dataset: object providing the feature metadata read below
    :return: `X` itself when both encodings have the same width, otherwise a new tensor of shape
        (length_of_sequence, num_of_features)
    """
    assert torch.is_tensor(X)
    assert X.shape == (dataset.length_of_sequence, dataset.input_dimension)

    plain_shape = (dataset.length_of_sequence, dataset.num_of_features)

    # nothing to undo when both encodings have the same width
    if dataset.input_dimension == dataset.num_of_features:
        return X

    decoded = torch.zeros(plain_shape)

    for column in range(dataset.num_of_features):
        if column < dataset.cate_index_start:
            # continuous feature: plain copy
            decoded[:, column] = X[:, column]
            continue

        # categorical feature: use the first entry whose name is contained in this feature's name
        # (the feature name may be `feature_name` + (Intervened))
        slots = next(
            (positions
             for known_name, positions in dataset.cate_one_hot_index_dict.items()
             if known_name in dataset.feature_names[column]),
            None,
        )
        if slots is not None:
            # invert one-hot back to the class index
            decoded[:, column] = X[:, slots].argmax(dim=-1)

    assert decoded.shape == plain_shape

    return decoded


def undo_all_one_hot(X, dataset):
    """
    Apply `undo_one_hot()` to every sample of a batch and stack the results along dim 0.

    :param X: tensor whose dims 1 and 2 are (length_of_sequence, input_dimension)
    :return: `X` itself when both encodings have the same width, otherwise the stacked decoded samples
    """
    assert torch.is_tensor(X)
    assert X.shape[1] == dataset.length_of_sequence
    assert X.shape[2] == dataset.input_dimension

    # nothing to undo when both encodings have the same width
    if dataset.input_dimension == dataset.num_of_features:
        return X

    decoded_samples = [undo_one_hot(sample, dataset).unsqueeze(dim=0) for sample in X]
    return torch.cat(decoded_samples, dim=0)


def plot_X(axis, X, feature_names, title, xlabel, ylabel, dataset, FIG_FONT_SIZE=20, decimals=3):
    """
    Plot a multivariate time series (non-one-hot format) on `axis`, one line per feature.

    Each feature gets its own marker/colour pair. When `feature_names` is given, a line is labelled with
    the feature name followed by its proximity weight rounded to `decimals` places.

    :raises ValueError: if the last dim of `X` is not `dataset.num_of_features`, or if there are more
        features than marker/colour pairs
    """
    n_features = dataset.num_of_features

    if X.shape[-1] != n_features:
        raise ValueError("X should not be one-hot encoded. X={}".format(X))

    style_pairs = create_color_maker_combo()

    if n_features > len(style_pairs):
        raise ValueError("Add more colour-maker combinations. num_of_features={}. len(mc_list)={}."
                         .format(n_features, len(style_pairs)))

    # one line per original time-series feature
    for column in range(n_features):
        marker_style, line_color = style_pairs[column]
        series = X[:, column]

        if feature_names is not None:
            weight = round(dataset.feature_proximity_weights[:, column].item(), decimals)
            line_label = feature_names[column] + " ({})".format(str(weight))
        else:
            line_label = None

        axis.plot(series, label=line_label, marker=marker_style, color=line_color)

    axis.set_xlabel(xlabel, fontsize=FIG_FONT_SIZE)
    axis.set_ylabel(ylabel, fontsize=FIG_FONT_SIZE)
    axis.tick_params(axis='both', labelsize=FIG_FONT_SIZE)
    axis.set_title(title, fontsize=FIG_FONT_SIZE)


def find_LSTM_best_validation_accuracy(root, result_folder_prefix, result_file_name, key_name):
    """
    Find the best validation accuracy among LSTM training results.

    Every directory directly under `root` whose name starts with `result_folder_prefix` is inspected.
    If it contains an entry named `result_file_name`, that file is scanned: for each line containing
    `key_name`, the following line is parsed as the accuracy.

    Progress and the final result are printed, and the result is also returned.

    :param root: directory holding the result folders
    :param result_folder_prefix: only folders whose name starts with this prefix are considered
    :param result_file_name: name of the result file inside each folder
    :param key_name: text marking the line that precedes an accuracy value
    :return: (best_acc, best_acc_configs); `best_acc_configs` lists the folder names that reached
        `best_acc`. `best_acc` starts at 0, so (0, []) is returned when no accuracy above 0 is found
        and none equals 0.
    :raises IndexError: if `key_name` appears on the last line of a result file
    :raises ValueError: if the line following `key_name` is not a number
    """

    best_acc = 0
    best_acc_configs = []

    for folder_name in os.listdir(root):
        if not folder_name.startswith(result_folder_prefix):
            continue

        folder_path = os.path.join(root, folder_name)
        if not os.path.isdir(folder_path):
            continue
        print(folder_path)

        if result_file_name not in os.listdir(folder_path):
            continue

        # read everything up front so the handle is closed before parsing
        with open(os.path.join(folder_path, result_file_name), 'r') as f:
            lines = f.readlines()

        print(lines)

        for i, line in enumerate(lines):
            if key_name not in line:
                continue

            accuracy = float(lines[i + 1])  # result is on the next line
            print("accuracy: ", accuracy)

            if accuracy > best_acc:
                best_acc = accuracy
                best_acc_configs = [folder_name]
            elif accuracy == best_acc:
                best_acc_configs.append(folder_name)

    print("\nDone.\n")
    print("best_acc: ", best_acc)
    print("best_acc_configs: ", best_acc_configs)

    return best_acc, best_acc_configs


def make_model_prediction(prediction_model, x, dataset):
    """
    Run `prediction_model` on a single one-hot encoded sample.

    The model is invoked according to `prediction_model.prediction_model_type`:
    LSTM models are called directly on `x`; Rule models get the non-one-hot sample through `.predict`;
    RandomForest / KNN models get the flattened non-one-hot sample through `.predict_proba`.

    :param x: tensor of shape (1, length_of_sequence, input_dimension)
    :return: the probability of being class 1
    :raises ValueError: for an unknown `prediction_model_type`
    """
    assert torch.is_tensor(x)
    assert x.shape == (1, dataset.length_of_sequence, dataset.input_dimension)

    model_type = prediction_model.prediction_model_type

    if model_type == Prediction_Model_Types.LSTM:
        return prediction_model(x)  # tensor (-1, 1)

    if model_type == Prediction_Model_Types.Rule:
        plain_sample = undo_one_hot(x.squeeze(0), dataset)
        return prediction_model.predict(plain_sample.unsqueeze(0)).reshape(-1, 1)  # tensor (-1, 1)

    if model_type == Prediction_Model_Types.RandomForest or model_type == Prediction_Model_Types.KNN:
        plain_sample = undo_one_hot(x.squeeze(0), dataset)
        class_probabilities = prediction_model.predict_proba(plain_sample.reshape(1, -1))
        assert class_probabilities.shape[-1] == 2
        # keep only the class-1 column
        return torch.tensor(class_probabilities[:, 1]).reshape(-1, 1)  # tensor (-1, 1)

    raise ValueError("what??? prediction_model_type={}".format(prediction_model.prediction_model_type))


def compute_DPP_diversity(unique_CFEs, dataset):
    """
    Compute the DPP diversity of a set of counterfactuals.

    A square matrix is built with entries `1 / (1 + proximity(i, j))`, where proximity is computed with
    equal feature weights and 0.0001 is added to every diagonal entry. The result is its determinant.

    Bigger distance ==> Bigger DPP diversity

    :return: the determinant as a Python float
    """
    n_cfes = len(unique_CFEs)

    kernel_values = []
    # row-major walk over every (row, col) pair
    for row, col in itertools.product(range(n_cfes), repeat=2):
        distance = compute_proximity(torch.tensor(unique_CFEs[row]), torch.tensor(unique_CFEs[col]), dataset, True)
        similarity = 1.0 / (1.0 + distance)
        kernel_values.append(similarity + 0.0001 if row == col else similarity)

    kernel = torch.tensor(kernel_values).reshape((n_cfes, n_cfes))
    return torch.linalg.det(kernel).item()


def compute_validity(generated_CFE_list, target_class, prediction_model, dataset):
    """
    Return the fraction of generated counterfactuals whose rounded prediction equals `target_class`.

    Validity measures the ratio of the counterfactuals that actually have the desired class label to the
    total number of counterfactuals generated. Higher validity is preferable.
    """
    generated_CFEs = torch.tensor(np.array(generated_CFE_list))

    assert generated_CFEs.shape[1] == dataset.length_of_sequence
    assert generated_CFEs.shape[2] == dataset.input_dimension

    # one True/False flag per counterfactual, in order
    hits = [
        bool(make_model_prediction(prediction_model, cfe.unsqueeze(0), dataset).round() == target_class)
        for cfe in generated_CFEs
    ]

    return float(sum(hits)) / float(len(hits))


def compute_proximity(x1, x2, dataset, force_equal_weights=False):
    """
    Compute the proximity between 2 Xs.

    Both inputs are first converted back to the non-one-hot format. The proximity is the sum of
      * the weighted L1 distance (sum of absolute differences) over the continuous features, and
      * the weighted L0 distance (number of non-zero differences) over the categorical features.
    Continuous and categorical features are treated differently on purpose, see DiCE.

    :param x1: tensor of shape (length_of_sequence, input_dimension), one-hot encoded
    :param x2: tensor of shape (length_of_sequence, input_dimension), one-hot encoded
    :param dataset: object providing `cate_index_start` and `feature_proximity_weights`
    :param force_equal_weights: if True, all feature weights are replaced by 1 (used for DPP diversity)
    :return: the proximity as a Python float
    :raises ValueError: if either input does not have the one-hot shape
    """
    expected_shape = (dataset.length_of_sequence, dataset.input_dimension)
    if x1.shape != expected_shape or x2.shape != expected_shape:
        # make sure they are in the one-hot encoding
        raise ValueError("What??? x1.shape={}, x2.shape={}".format(x1.shape, x2.shape))

    x1_original_format = undo_one_hot(x1, dataset)
    x2_original_format = undo_one_hot(x2, dataset)

    split = dataset.cate_index_start

    proximity_weights_continuous = dataset.feature_proximity_weights[:, :split]
    proximity_weights_categorical = dataset.feature_proximity_weights[:, split:]

    if force_equal_weights:  # for DPP diversity computation
        proximity_weights_continuous = torch.ones_like(proximity_weights_continuous)
        proximity_weights_categorical = torch.ones_like(proximity_weights_categorical)

    weighted_diff_continuous = \
        (x1_original_format[:, :split] - x2_original_format[:, :split]) * proximity_weights_continuous
    weighted_diff_categorical = \
        (x1_original_format[:, split:] - x2_original_format[:, split:]) * proximity_weights_categorical

    # L1-norm for continuous features.
    # Flatten first: on a 2-D tensor `torch.linalg.norm(..., ord=1)` is the matrix 1-norm
    # (largest absolute column sum), which would only count the single most-changed feature.
    proximity_continuous = torch.linalg.norm(weighted_diff_continuous.flatten(), ord=1).item()

    # L0-norm for categorical features
    proximity_categorical = torch.linalg.norm(weighted_diff_categorical.flatten(), ord=0).item()

    return proximity_continuous + proximity_categorical


def compute_average_proximity(original_X, generated_CFE_list, dataset):
    """
    Return the mean of `compute_proximity(original_X, cfe, dataset)` over all generated CFEs.
    """
    generated_CFEs = torch.tensor(np.array(generated_CFE_list))

    assert generated_CFEs.shape[1] == dataset.length_of_sequence
    assert generated_CFEs.shape[2] == dataset.input_dimension

    proximities = [compute_proximity(original_X, cfe, dataset) for cfe in generated_CFEs]

    # plain left-to-right accumulation (kept explicit instead of using `sum`)
    accumulated = 0.0
    for value in proximities:
        accumulated += value

    return accumulated / float(len(proximities))


def compute_sparsity(x1, x2, dataset):
    """
    Compute the sparsity between 2 Xs: the number of entries that differ once both are converted back
    to the non-one-hot format (L0-norm of the difference).

    :raises ValueError: if either input does not have shape (length_of_sequence, input_dimension)
    """
    one_hot_shape = (dataset.length_of_sequence, dataset.input_dimension)

    # both inputs must be in the one-hot encoding
    if not (x1.shape == one_hot_shape and x2.shape == one_hot_shape):
        raise ValueError("What??? x1.shape={}, x2.shape={}".format(x1.shape, x2.shape))

    difference = undo_one_hot(x1, dataset) - undo_one_hot(x2, dataset)

    # L0-norm of the flattened difference
    l0_norm = torch.linalg.norm(difference.flatten(), ord=0)
    return l0_norm.item()


def compute_average_sparsity(original_X, generated_CFE_list, dataset):
    """
    Return the mean of `compute_sparsity(original_X, cfe, dataset)` over all generated CFEs.
    """
    generated_CFEs = torch.tensor(np.array(generated_CFE_list))

    assert generated_CFEs.shape[1] == dataset.length_of_sequence
    assert generated_CFEs.shape[2] == dataset.input_dimension

    sparsities = [compute_sparsity(original_X, cfe, dataset) for cfe in generated_CFEs]

    # plain left-to-right accumulation (kept explicit instead of using `sum`)
    accumulated = 0.0
    for value in sparsities:
        accumulated += value

    return accumulated / float(len(sparsities))


def compute_plausibility(x_sample, dataset, lof=None, X_train=None, n_neighbors=20):
    """
    Following previous works, we use Local Outlier Factor (LOF) to measure plausibility.

    :param x_sample: one-hot encoded sample of shape (length_of_sequence, input_dimension)
    :param dataset: object providing the shape metadata read below
    :param lof: an already fitted LOF model; when None, one is trained from `X_train`
    :param X_train: one-hot training data, only used when `lof` is None
    :param n_neighbors: `n_neighbors` for the LOF trained when `lof` is None
        (20 is scikit-learn's own default)
    :return: `Plausibility.plausible` (1) or `Plausibility.implausible` (-1) as a Python int
    :raises ValueError: if neither `lof` nor `X_train` is given, or if the LOF returns an unexpected label
    """

    assert x_sample.shape == (dataset.length_of_sequence, dataset.input_dimension)

    if lof is None:
        if X_train is None:
            raise ValueError("Either `lof` or `X_train` must be provided.")
        # `train_LOF` takes `n_neighbors` as its first argument
        lof = train_LOF(n_neighbors, X_train, dataset)

    # LOF was fitted on flattened non-one-hot samples, so present the sample the same way (1 row)
    x_sample_original_format = undo_one_hot(x_sample, dataset).unsqueeze(0) \
        .reshape(-1, dataset.length_of_sequence * dataset.num_of_features)

    # Returns -1 for anomalies/outliers and +1 for inliers.
    plausibility = lof.predict(x_sample_original_format)

    if plausibility != Plausibility.implausible and plausibility != Plausibility.plausible:
        raise ValueError("what??? plausibility={}".format(plausibility))

    return plausibility.item()


def train_LOF(n_neighbors, X_one_hot, dataset):
    """
    Fit a Local Outlier Factor model (novelty mode, euclidean metric) on the training data.

    The one-hot samples are converted back to the non-one-hot format and flattened to one row each.

    :param n_neighbors: number of neighbours passed to `LocalOutlierFactor`
    :param X_one_hot: tensor whose dims 1 and 2 are (length_of_sequence, input_dimension)
    :return: the fitted `LocalOutlierFactor`
    """
    assert X_one_hot.shape[1] == dataset.length_of_sequence
    assert X_one_hot.shape[2] == dataset.input_dimension

    training_rows = undo_all_one_hot(X_one_hot, dataset).reshape(
        -1, dataset.length_of_sequence * dataset.num_of_features
    )

    detector = LocalOutlierFactor(n_neighbors=n_neighbors, novelty=True, metric='euclidean')
    detector.fit(training_rows)

    return detector


def make_LOF(X_one_hot, dataset, result_summary_file):
    """
    Find the highest `n_neighbors` that does not return `implausible`.
    Then, use it to train a Local Outlier Factor (LOF) to measure plausibility.

    `n_neighbors` is increased from 1 up to int(sqrt(number of samples)). The search stops at the first
    value for which some training sample is not plausible, and the previous value is used. If that would
    give a value below 1 (which `LocalOutlierFactor` does not accept), 1 is used instead and a notice is
    printed.

    :param X_one_hot: training data; dims 1 and 2 are (length_of_sequence, input_dimension)
    :param dataset: object providing the shape metadata read below
    :param result_summary_file: file to which the chosen `n_neighbors` is appended
    :return: the LOF model trained with the chosen `n_neighbors`
    """

    assert X_one_hot.shape[1] == dataset.length_of_sequence
    assert X_one_hot.shape[2] == dataset.input_dimension

    max_n_neighbors = int(np.sqrt(len(X_one_hot)))

    # used when every candidate keeps all training samples plausible
    final_n_neighbors = max_n_neighbors

    for n_neighbors in range(1, max_n_neighbors + 1):

        lof = train_LOF(n_neighbors, X_one_hot, dataset)

        # the computed `plausibility` should be `plausible` for all training data points.
        rejected_label = None
        for x in X_one_hot:
            plausibility = compute_plausibility(x, dataset, lof)
            if plausibility != Plausibility.plausible:
                rejected_label = plausibility
                break

        if rejected_label is not None:
            final_n_neighbors = n_neighbors - 1
            print("plausibility={} for a training data with n_neighbors={}. So, final_n_neighbors={}."
                  .format(rejected_label, n_neighbors, final_n_neighbors))
            break

    if final_n_neighbors < 1:
        # happens when n_neighbors=1 already rejects a training sample, or with fewer than 1 sample
        print("final_n_neighbors={} is not a valid LOF setting. Using n_neighbors=1 instead."
              .format(final_n_neighbors))
        final_n_neighbors = 1

    output_str = "\n`n_neighbors` used for LOF: {}\n"
    print(output_str.format(final_n_neighbors))
    append_to_file(result_summary_file, output_str.format(final_n_neighbors))

    return train_LOF(final_n_neighbors, X_one_hot, dataset)


def compute_average_margin_difference(generated_CFE_list, target_boundary, prediction_model, dataset):
    """
    Return the mean of (prediction - `target_boundary`) over all generated CFEs.

    This is the margin difference defined in LatentCF++ (Equation 3): the difference between the
    prediction and the target Y decision boundary.
    """
    generated_CFEs = torch.tensor(np.array(generated_CFE_list))

    assert generated_CFEs.shape[1] == dataset.length_of_sequence
    assert generated_CFEs.shape[2] == dataset.input_dimension

    margin_diffs = [
        make_model_prediction(prediction_model, cfe.unsqueeze(0), dataset).item() - target_boundary
        for cfe in generated_CFEs
    ]

    # plain left-to-right accumulation (kept explicit instead of using `sum`)
    accumulated = 0.0
    for value in margin_diffs:
        accumulated += value

    return accumulated / float(len(margin_diffs))

def _parse_elapsed_seconds(text):
    """
    Convert an elapsed-time string to seconds.

    Accepts `H:MM:SS` with an optional fractional part and an optional leading `N day(s), ` prefix
    (presumably the value was written with `str(datetime.timedelta)` -- confirm against the writer).

    :raises ValueError: if the text does not have that form
    """
    text = text.strip()
    days = 0
    if "day" in text:
        day_part, _, text = text.partition(", ")
        days = int(day_part.split()[0])
    hours, minutes, seconds = text.split(":")
    return datetime.timedelta(days=days, hours=int(hours), minutes=int(minutes),
                              seconds=float(seconds)).total_seconds()


def read_all_result_files(result_folder_root):
    """
    Aggregate the per-sample result files found under `result_folder_root`.

    Every direct sub-folder is expected to contain a file named `0result_current_X.txt` made of
    `key: value` lines. Lines without a ":" are ignored. Recognised keys are `Time Elapsed`, `Validity`,
    `Validity (without Brute Force Search)` (ignored), `Proximity`, `Sparsity` and `Plausibility`.

    The "valid" totals only include values that come after a `Validity: True` line in the same file;
    the flag is reset for every file.

    :param result_folder_root: directory whose direct sub-folders hold the result files
    :return: dict with the counters and totals (same keys as the parameters of `log_result_summary`)
    :raises ValueError: on an unknown key, an unexpected `Validity`/`Plausibility` value, a malformed
        elapsed time, or more than one `Time Elapsed` line in a file
    """
    summary = {
        "total_elapsed_seconds": 0,
        "count_invalid_X": 0,
        "count_CFE_found": 0,
        "count_valid_CFE_found": 0,
        "count_plausible": 0,
        "count_plausible_valid": 0,
        "proximity_total": 0,
        "sparsity_total": 0,
        "proximity_valid_total": 0,
        "sparsity_valid_total": 0,
    }

    index_folders = next(os.walk(result_folder_root))[1]  # direct sub-folders only
    for i_folder in index_folders:
        result_file = os.path.join(result_folder_root, i_folder, "0result_current_X.txt")
        with open(result_file, 'r') as f:
            summary["count_invalid_X"] += 1
            time_elapsed_entries = 0
            # reset per file so a previous file's validity never leaks into this one
            valid_cfe_found = False

            for line in f:
                if ":" not in line:
                    continue

                # a line without ": " keeps the whole line as key and is reported as unexpected below
                key, _, value = line.partition(": ")

                if key == "Time Elapsed":
                    time_elapsed_entries += 1
                    if time_elapsed_entries > 1:
                        raise ValueError(f"Duplicate result detected at: {result_file}")
                    summary["total_elapsed_seconds"] += _parse_elapsed_seconds(value)

                elif key == "Validity":
                    summary["count_CFE_found"] += 1
                    validity_text = value.strip()
                    if validity_text == "True":
                        valid_cfe_found = True
                    elif validity_text == "False":
                        valid_cfe_found = False
                    else:
                        raise ValueError(f"Unexpected Validity value: {value}")

                    if valid_cfe_found:
                        summary["count_valid_CFE_found"] += 1

                elif key == "Validity (without Brute Force Search)":
                    pass

                elif key == "Proximity":
                    proximity = float(value)
                    summary["proximity_total"] += proximity
                    if valid_cfe_found:
                        summary["proximity_valid_total"] += proximity

                elif key == "Sparsity":
                    sparsity = float(value)
                    summary["sparsity_total"] += sparsity
                    if valid_cfe_found:
                        summary["sparsity_valid_total"] += sparsity

                elif key == "Plausibility":
                    plausibility = float(value)
                    if plausibility == Plausibility.plausible:
                        summary["count_plausible"] += 1
                        if valid_cfe_found:
                            summary["count_plausible_valid"] += 1
                    elif plausibility == Plausibility.implausible:
                        pass
                    else:
                        raise ValueError("what??? plausibility={}".format(plausibility))

                else:
                    raise ValueError(f"Unexpected line in {result_file}: {line}")

    return summary



def log_result_summary(result_summary_file, total_start_time, count_invalid_X, count_CFE_found, count_valid_CFE_found,
                       count_plausible, count_plausible_valid, proximity_total, sparsity_total, proximity_valid_total,
                       sparsity_valid_total, total_elapsed_seconds=None):
    """
    Print the aggregated metrics and append them to `result_summary_file`.

    Each ratio is reported as "N/A" when its denominator is 0. Exactly one of `total_start_time`
    (elapsed time is measured up to now) and `total_elapsed_seconds` (reported as given) must be set;
    this is only checked after the metrics have been written.

    :raises ValueError: if both or neither of `total_start_time` / `total_elapsed_seconds` are given
    """

    def emit(text):
        # every metric line goes to stdout and to the summary file
        print(text)
        append_to_file(result_summary_file, text)

    def ratio(numerator, denominator, as_percentage):
        if denominator == 0:
            return "N/A"
        quotient = numerator / denominator
        return quotient * 100 if as_percentage else quotient

    emit("\nNumber of invalid Xs: {}\n".format(count_invalid_X))

    # (template, numerator, denominator, report as percentage?)
    metric_rows = (
        ("\nSuccess Rate: Percentage of (count_valid_CFE_found / count_invalid_X): {}%\n",
         count_valid_CFE_found, count_invalid_X, True),
        ("\nValidity: percentage of (count_valid_CFE_found / count_CFE_found): {}%\n",
         count_valid_CFE_found, count_CFE_found, True),
        ("\nAverage Proximity: (proximity_total / count_CFE_found): {}\n",
         proximity_total, count_CFE_found, False),
        ("\nAverage Proximity: (proximity_valid_total / count_valid_CFE_found): {}\n",
         proximity_valid_total, count_valid_CFE_found, False),
        ("\nAverage Sparsity: (sparsity_total / count_CFE_found): {}\n",
         sparsity_total, count_CFE_found, False),
        ("\nAverage Sparsity: (sparsity_valid_total / count_valid_CFE_found): {}\n",
         sparsity_valid_total, count_valid_CFE_found, False),
        ("\nPercentage of (count_plausible / count_CFE_found): {}%\n",
         count_plausible, count_CFE_found, True),
        ("\nPlausibility: Percentage of (count_plausible_valid / count_valid_CFE_found): {}%\n",
         count_plausible_valid, count_valid_CFE_found, True),
    )
    for template, numerator, denominator, as_percentage in metric_rows:
        emit(template.format(ratio(numerator, denominator, as_percentage)))

    has_start_time = total_start_time is not None
    has_elapsed_seconds = total_elapsed_seconds is not None

    if has_start_time and has_elapsed_seconds:
        raise ValueError("One of total_start_time and total_elapsed_seconds should be None.")
    if not has_start_time and not has_elapsed_seconds:
        raise ValueError("Both total_start_time and total_elapsed_seconds are None.")

    if has_start_time:
        total_elapsed = str(datetime.timedelta(seconds=time.time() - total_start_time))
        print("Total Time Elapsed: {}".format(total_elapsed))
        append_to_file(result_summary_file, "\nTotal Time Elapsed: {}\n".format(total_elapsed))
    else:
        print("Total Seconds Elapsed: {}".format(total_elapsed_seconds))
        append_to_file(result_summary_file, "\nTotal Seconds Elapsed: {}\n".format(total_elapsed_seconds))


def get_default_prediction_model_version(prediction_model_version, dataset_name):
    """
    Return `prediction_model_version` if it is given, otherwise the default version for `dataset_name`.

    :raises ValueError: if no version is given and `dataset_name` is not a known dataset
    """
    # an explicitly requested version always wins
    if prediction_model_version is not None:
        return prediction_model_version

    default_versions = (
        (Dataset_Names.life_expectancy, ""),  # add path to your prediction model here
        (Dataset_Names.natops, ""),
        (Dataset_Names.heartbeat, ""),
        (Dataset_Names.racket_sports, ""),
        (Dataset_Names.basic_motions, ""),
        (Dataset_Names.ering, ""),
        (Dataset_Names.japanese_vowels, ""),
        (Dataset_Names.libras, ""),
        (Dataset_Names.PEMS_SF, ""),
    )

    for known_name, default_version in default_versions:
        if dataset_name == known_name:
            return default_version

    raise ValueError("what??? dataset_name={}".format(dataset_name))
