# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import argparse
import os
import sys
import time
import warnings
from datetime import timedelta
from pathlib import Path
import numpy as np
import torch
import utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from Dataset_BasicMotions import Dataset_BasicMotions
from Dataset_ERing import Dataset_ERing
from Dataset_Heartbeat import Dataset_Heartbeat
from Dataset_JapaneseVowels import Dataset_JapaneseVowels
from Dataset_Libras import Dataset_Libras
from Dataset_Life_Expectancy import Dataset_Life_Expectancy
from Dataset_NATOPS import Dataset_NATOPS
from Dataset_PEMS_SF import Dataset_PEMS_SF
from Dataset_RacketSports import Dataset_RacketSports
from main_LSTM import LSTM_Model
from torch.distributions import Categorical, Normal
from collections import deque
from SCM_utils import SCM
from main_Rule_Model import Rule_Model

# Base figure width/height used to build the matplotlib `figsize` tuples in the plotting code below
# (each figure scales these, e.g. `WHOLE_FIG_W * 2` or `WHOLE_FIG_H / 2`).
WHOLE_FIG_W = 15
WHOLE_FIG_H = 15
# Font size passed to `utils.plot_X`; the figure legend uses a slightly reduced value derived from it.
FIG_FONT_SIZE = 20


class PolicyNetwork(nn.Module):
    """
    Policy network for the REINFORCE agent: a shared two-layer trunk followed by several output heads.

    Heads:
    - `feature_fc_last`: logits over which feature to intervene on (one per original feature).
    - `timestep_fc_last`: logits over which time step to intervene on.
    - `cont_feats_mean_fc_last` / `cont_feats_std_fc_last`: one scalar head each per continuous feature,
      giving the mean and (softplus-ed) standard deviation of the intervention value.
    - `cate_feats_fc_last`: one head per categorical feature (keyed by feature name), giving class logits.

    The instance also stores the per-step log-probabilities and rewards of the current episode
    (`saved_log_probs_*`, `rewards`).

    :param dataset: object providing `length_of_sequence`, `input_dimension`, `num_of_features` and
                    `cate_one_hot_index_dict` (categorical feature name -> list of one-hot column indexes).
    """

    def __init__(self, dataset):
        super().__init__()

        num_neurons1 = 1000
        num_neurons2 = 100

        self.fc1 = nn.Linear(dataset.length_of_sequence * dataset.input_dimension, num_neurons1)
        self.fc2 = nn.Linear(num_neurons1, num_neurons2)

        self.feature_fc_last = nn.Linear(num_neurons2, dataset.num_of_features)
        self.timestep_fc_last = nn.Linear(num_neurons2, dataset.length_of_sequence)

        # The per-feature heads must be held in an nn.ModuleList: layers stored in a plain Python list are
        # not registered as sub-modules, so their weights would be missing from `parameters()` (never
        # optimized), from `state_dict()` (never saved) and would not follow `.to(device)`.
        # Mean and std heads are created alternately, so the weight-initialisation order is unchanged.
        self.cont_feats_mean_fc_last = nn.ModuleList()
        self.cont_feats_std_fc_last = nn.ModuleList()
        for _ in range(dataset.num_of_features - len(dataset.cate_one_hot_index_dict)):
            self.cont_feats_mean_fc_last.append(nn.Linear(num_neurons2, 1))
            self.cont_feats_std_fc_last.append(nn.Linear(num_neurons2, 1))

        # Name -> head lookup stays a plain dict (feature names are arbitrary strings, while nn.ModuleDict
        # rejects keys containing "." or clashing with existing Module attributes) ...
        self.cate_feats_fc_last = {}
        for name, indexes in dataset.cate_one_hot_index_dict.items():
            number_of_classes = len(indexes)
            self.cate_feats_fc_last[name] = nn.Linear(num_neurons2, number_of_classes)
        # ... and the same layer objects are registered here so that their parameters are tracked.
        self._cate_feats_heads = nn.ModuleList(self.cate_feats_fc_last.values())

        # Per-episode buffers, filled while acting and emptied once the episode has been used for an update.
        self.saved_log_probs_feature = []
        self.saved_log_probs_timestep = []
        self.saved_log_probs_value = []
        self.rewards = []

    def check_for_continue(self):
        """Assert that all per-episode buffers are empty."""
        assert len(self.saved_log_probs_feature) == 0
        assert len(self.saved_log_probs_timestep) == 0
        assert len(self.saved_log_probs_value) == 0
        assert len(self.rewards) == 0

    def forward(self, input):
        """
        Compute the action distributions' parameters for a flattened state.

        :param input: tensor whose last dimension is `length_of_sequence * input_dimension`. The masking of
                      non-actionable features below only touches row 0, i.e. a batch of size 1 is assumed.
        :return: tuple of
                 - feature selection probabilities, with entries of non-actionable features set to 0
                   (not re-normalized here),
                 - time step selection probabilities,
                 - list of mean tensors, one per continuous feature,
                 - list of standard deviation tensors, one per continuous feature,
                 - dict: categorical feature name -> class probabilities.
        """
        x = self.fc1(input)
        x = F.relu(x)
        x = self.fc2(x)

        action_feature = self.feature_fc_last(x)  # feature index in the original X, not in the one-hot encoding format
        action_timestep = self.timestep_fc_last(x)

        action_feature_probs = F.softmax(action_feature, dim=1)
        action_timestep_probs = F.softmax(action_timestep, dim=1)

        cont_feats_means = []
        for fc_last in self.cont_feats_mean_fc_last:
            cont_feats_means.append(fc_last(x))

        cont_feats_stds = []
        for fc_last in self.cont_feats_std_fc_last:
            cont_feats_stds.append(F.softplus(fc_last(x)))  # add a tiny constant if 0

        cate_feats_probs = {}
        for name, fc_last in self.cate_feats_fc_last.items():
            cate_feat_action = fc_last(x)
            cate_feats_probs[name] = F.softmax(cate_feat_action, dim=1)

        # set the intervention probabilities on immutable features to 0
        # (`nonactionable_feature_indexes` is a module-level global defined outside this class)
        actionable_feature_probs = action_feature_probs.clone()
        for nonact_i in nonactionable_feature_indexes:
            actionable_feature_probs[0, nonact_i] = 0.0

        return actionable_feature_probs, action_timestep_probs, cont_feats_means, cont_feats_stds, cate_feats_probs


def select_action(state):
    """
    Sample one intervention (feature, time step, value) from the policy for the given flattened state.

    Side effects:
    - appends the log-probabilities of the three sampled components to the buffers of `rl_model`;
    - stores the entropies of the feature and time step distributions in the module-level globals
      `feature_entropy` and `timestep_entropy`.

    :param state: tensor of shape (1, length_of_sequence * input_dimension)
    :return: (intervene_on_categorical, feature index, time step index, intervention value); the last
             three are plain Python numbers obtained via `.item()`.
    """
    global feature_entropy, timestep_entropy

    assert state.shape == (1, dataset.length_of_sequence * dataset.input_dimension)

    (action_feature_probs, action_timestep_probs,
     cont_feats_means, cont_feats_stds, cate_feats_probs) = rl_model(state)

    assert action_feature_probs.shape[-1] == dataset.num_of_features
    assert action_timestep_probs.shape[-1] == dataset.length_of_sequence
    assert len(cont_feats_means) + len(cate_feats_probs) == dataset.num_of_features
    assert len(cont_feats_stds) + len(cate_feats_probs) == dataset.num_of_features

    # which feature to intervene on
    feature_dist = Categorical(action_feature_probs)
    action_feature = feature_dist.sample()

    # at which time step
    timestep_dist = Categorical(action_timestep_probs)
    action_timestep = timestep_dist.sample()

    # entropies of both selections, read later when the reward is computed
    feature_entropy = feature_dist.entropy()
    timestep_entropy = timestep_dist.entropy()

    rl_model.saved_log_probs_feature.append(feature_dist.log_prob(action_feature))
    rl_model.saved_log_probs_timestep.append(timestep_dist.log_prob(action_timestep))

    # features at or after `cate_index_start` are treated as categorical
    intervene_on_categorical = bool(dataset.cate_index_start is not None
                                    and action_feature >= dataset.cate_index_start)

    if not intervene_on_categorical:
        # continuous feature: the value is drawn from a Normal with the predicted mean/std
        value_dist = Normal(cont_feats_means[action_feature], cont_feats_stds[action_feature])
        action_value = value_dist.sample()
        rl_model.saved_log_probs_value.append(value_dist.log_prob(action_value))
    else:
        # categorical feature: the class is drawn from the head registered under the feature's name
        selected_name = dataset.feature_names[action_feature]
        value_dist = None
        if selected_name in cate_feats_probs:
            value_dist = Categorical(cate_feats_probs[selected_name])

        action_value = value_dist.sample()
        # reshaped to (1, 1) so that it has the same shape as the log-prob of the continuous case
        rl_model.saved_log_probs_value.append(value_dist.log_prob(action_value).reshape((1, 1)))

    return intervene_on_categorical, action_feature.item(), action_timestep.item(), action_value.item()


def finish_episode():
    """
    Perform one REINFORCE update from the episode stored in `rl_model`, then empty its buffers.

    The return of each step is the discounted (by the global `gamma`) sum of the rewards from that step
    on. The loss is the sum over steps of `-(log-prob of feature + time step + value) * return`.
    """
    # accumulate the discounted returns backwards, then restore chronological order
    running_return = 0
    discounted = []
    for reward in reversed(rl_model.rewards):
        running_return = reward + gamma * running_return
        discounted.append(running_return)
    discounted.reverse()
    returns = torch.tensor(discounted)

    step_losses = [
        -(lp_feature + lp_timestep + lp_value) * step_return
        for lp_feature, lp_timestep, lp_value, step_return in zip(rl_model.saved_log_probs_feature,
                                                                   rl_model.saved_log_probs_timestep,
                                                                   rl_model.saved_log_probs_value,
                                                                   returns)
    ]

    optimizer.zero_grad()
    total_loss = torch.cat(step_losses).sum()
    total_loss.backward()
    optimizer.step()

    # the buffers are emptied in place so that references held elsewhere stay valid
    rl_model.rewards.clear()
    rl_model.saved_log_probs_feature.clear()
    rl_model.saved_log_probs_timestep.clear()
    rl_model.saved_log_probs_value.clear()


def env_reward(state):
    """
    Compute the reward of a (possibly intervened) state.

    reward = prediction term - lambda_proximity * proximity + lambda_entropy * (feature + time step entropy)

    The prediction term is 1.0 when the rounded model prediction equals `target_class` and 0.0 when it
    equals `invalid_class`. The entropies are the globals written by the most recent `select_action` call.

    :param state: the state to evaluate; `state[0]` is compared with the global `original_x`
    :return: (reward, y_pred) where `y_pred` is the un-rounded model prediction
    :raises ValueError: if the rounded prediction is neither `target_class` nor `invalid_class`
    """
    y_pred = utils.make_model_prediction(prediction_model, state, dataset)
    predicted_class = y_pred.round()

    # A 0/1 prediction term is used instead of the predicted probability: 0.6 and 0.7 are equivalent as
    # long as they round to the same class.
    if predicted_class == target_class:
        prediction_loss = 1.0
        proximity_loss = utils.compute_proximity(torch.reshape(original_x, state[0].shape), state[0], dataset)
    elif predicted_class == invalid_class:
        # Proximity only counts once the desired prediction is reached. Otherwise, with sparse rewards, the
        # agent would be pushed towards reproducing the original (invalid) X just to minimize proximity.
        prediction_loss, proximity_loss = 0.0, 0.0
    else:
        raise ValueError("what??? y_pred.round()={}".format(predicted_class))

    weighted_proximity = lambda_proximity * proximity_loss

    # maximum entropy RL
    entropy_bonus = lambda_entropy * (feature_entropy + timestep_entropy).item()
    reward = prediction_loss - weighted_proximity + entropy_bonus

    # The prediction term is meant to dominate the weighted proximity term, since reaching the desired
    # prediction matters more than staying close.
    if prediction_loss < weighted_proximity:
        warnings.warn("prediction_loss {} < lambda_proximity {} * proximity_loss {}. "
                      "Consider reducing `lambda_proximity`."
                      .format(prediction_loss, lambda_proximity, proximity_loss))

    return reward, y_pred


def env_step(intervene_on_categorical, state, action_feature, action_timestep, action_value):
    """
    Apply one intervention to `state` and evaluate the result.

    The intervention is written into a clone of `state`, optionally clamped to the feature's value range
    (continuous features only, when the global `feature_extreme_values` is truthy), propagated through the
    SCM functions and finally scored by `env_reward`.

    :return: (new_state, reward, done), where done=True if a valid CFE is found, i.e. the rounded
             prediction on `new_state` equals `target_class`.
    :raises ValueError: for an unknown `intervention_type`, for a point intervention on a categorical
                        feature (not implemented), or for an unexpected rounded prediction.
    """
    new_state = state.clone()

    if intervene_on_categorical:
        # overwrite the one-hot columns of the selected categorical feature
        selected_name = dataset.feature_names[action_feature]
        for name, indexes in dataset.cate_one_hot_index_dict.items():
            if name != selected_name:
                continue

            if intervention_type == "point":  # point intervention
                raise ValueError("To be implemented.")
            elif intervention_type == "drifting":  # drifting
                new_state[0, action_timestep:, indexes] = 0.0
                new_state[0, action_timestep:, indexes[action_value]] = 1.0
            else:
                raise ValueError("What??? intervention_type={}".format(intervention_type))
            break
    else:
        # shift the selected continuous feature by `action_value`
        if intervention_type == "point":  # only the selected time step
            time_index = action_timestep
        elif intervention_type == "drifting":  # the selected time step and everything after it
            time_index = slice(action_timestep, None)
        else:
            raise ValueError("What??? intervention_type={}".format(intervention_type))

        new_state[0, time_index, action_feature] = state[0, time_index, action_feature] + action_value

        # keep the feature within its max and min values
        if feature_extreme_values:
            new_state = apply_range(new_state, action_timestep, action_feature)

    new_state = apply_SCM(action_timestep, action_feature, new_state)

    reward, y_pred = env_reward(new_state)

    predicted_class = y_pred.round()
    if predicted_class == target_class:
        done = True
    elif predicted_class == invalid_class:
        done = False
    else:
        raise ValueError("what??? y_pred.round()={}".format(predicted_class))

    return new_state, reward, done


def apply_range(X, action_timestep, action_feature):
    """
    Clamp a continuous feature to its [X_min, X_max] range, from `action_timestep` to the end.

    `X` is modified in place and also returned.

    :raises ValueError: if `action_feature` is the index of a categorical feature
    """
    first_categorical = dataset.cate_index_start
    if first_categorical is not None and action_feature >= first_categorical:
        raise ValueError("It should be for continuous feature only. action_feature={}".format(action_feature))

    lower, upper = X_min[action_feature], X_max[action_feature]
    tail = X[0, action_timestep:, action_feature]
    X[0, action_timestep:, action_feature] = torch.clamp(tail, min=lower, max=upper)

    return X


def apply_SCM(intervened_timestep, intervened_fea_index, X):
    """
    Propagate an intervention through the structural causal model (SCM) functions.

    A fresh `SCM` is filled with the current feature values taken from `X`, the global `scm_functions`
    are registered on top of them, the intervened feature is fixed with `scm.do`, and every SCM child
    that has the intervened feature among its ancestors is re-sampled and written into a copy of `X`.

    With intervention_type "point" only the values at `intervened_timestep` are used/updated; with
    "drifting" the values from `intervened_timestep` to the end are used/updated.

    :param intervened_timestep: time step index selected by the agent
    :param intervened_fea_index: index (into `dataset.feature_names`) of the feature changed by the agent
    :param X: state after the intervention; indexed as X[0, time step, column]. It is not modified.
    :return: a clone of `X` in which the affected SCM children are overwritten by their sampled values
    :raises ValueError: for an unknown `intervention_type`; for a "point" intervention when the dataset
                        has categorical features (not implemented); or, via
                        `check_one_hot_encoding_validity`, if `X[0]` holds an invalid one-hot encoding.
    """
    # create SCM instance
    # start with a new SCM every time, to make sure no overdue values in the SCM, which may cause problem
    scm = SCM()

    # add current variable values X to SCMs
    for variable_name in dataset.feature_names:
        if variable_name not in dataset.categorical_features:  # add continuous features
            if intervention_type == "point":
                scm.add_variable(variable_name, X[0, intervened_timestep,
                dataset.feature_names.index(variable_name)])
            elif intervention_type == "drifting":
                scm.add_variable(variable_name, X[0, intervened_timestep:,
                                                dataset.feature_names.index(variable_name)])
            else:
                raise ValueError("What??? intervention_type={}".format(intervention_type))

    # Categorical variables are registered with the values returned by `utils.undo_one_hot`
    # (one column per original feature) rather than with their one-hot columns.
    check_one_hot_encoding_validity(X[0])
    X_plot_format = utils.undo_one_hot(X[0], dataset)
    for variable_name, indexes in dataset.cate_one_hot_index_dict.items():  # add categorical features

        if intervention_type == "point":  # point intervention
            raise ValueError("To be implemented.")
        elif intervention_type == "drifting":  # drifting
            plot_format_values = X_plot_format[intervened_timestep:, dataset.feature_names.index(variable_name)]
        else:
            raise ValueError("What??? intervention_type={}".format(intervention_type))

        scm.add_variable(variable_name, plot_format_values)

    # add SCM functions, which overwrite the current variable values added above
    for child, function in scm_functions.items():
        scm.add_variable(child, function)

    # Do the intervention on the variable that is changed by the RL agent.
    # This breaks the connection between its parents, so SCM functions won't change the intervened value.
    intervened_fea_name = dataset.feature_names[intervened_fea_index]

    if intervened_fea_name in dataset.categorical_features:  # intervened on categorical variable
        if intervention_type == "point":  # point intervention
            raise ValueError("To be implemented.")
        elif intervention_type == "drifting":  # drifting
            intervened_value = X_plot_format[intervened_timestep:, intervened_fea_index]
        else:
            raise ValueError("What??? intervention_type={}".format(intervention_type))
    else:  # intervened on continuous variable
        if intervention_type == "point":
            intervened_value = X[0, intervened_timestep, intervened_fea_index]
        elif intervention_type == "drifting":
            intervened_value = X[0, intervened_timestep:, intervened_fea_index]
        else:
            raise ValueError("What??? intervention_type={}".format(intervention_type))

    scm.do(intervened_fea_name, intervened_value)

    # results are written into a copy so that the caller's tensor stays untouched
    new_X = X.clone()

    # apply the SCM functions
    for child, _ in scm_functions.items():

        # Currently, the SCM is applied only after intervention on an ancestor of a pre-defined SCM function.
        # It means, without intervention on an ancestor of a pre-defined SCM function, the original data related to the SCM won't change.
        # For A->B, if A is not intervened, then B will not change
        # For A->B->C, if A is intervened, B and C will be updated
        if intervened_fea_name not in scm.get_ancestors(child):
            continue

        sampled_child = scm.sample(child)

        if child in dataset.categorical_features:  # categorical child variable
            for name, indexes in dataset.cate_one_hot_index_dict.items():
                if name == child:
                    # convert the sampled class values back into one-hot columns before writing them
                    sampled_child_one_hot = F.one_hot(sampled_child.to(torch.int64), num_classes=len(indexes))

                    if intervention_type == "point":
                        new_X[0, intervened_timestep, indexes] = sampled_child_one_hot.to(torch.float)
                    elif intervention_type == "drifting":
                        new_X[0, intervened_timestep:, indexes] = sampled_child_one_hot.to(torch.float)
                    else:
                        raise ValueError("What??? intervention_type={}".format(intervention_type))
                    break

        else:  # continuous child variable
            if intervention_type == "point":
                new_X[0, intervened_timestep, dataset.feature_names.index(child)] = sampled_child
            elif intervention_type == "drifting":
                new_X[0, intervened_timestep:, dataset.feature_names.index(child)] = sampled_child
            else:
                raise ValueError("What??? intervention_type={}".format(intervention_type))

    return new_X


def plot_training_info(xlabel, ylabel, png_file, *data):
    """
    Plot one or more per-episode series as dots and save the figure into `result_folder_index`.

    :param xlabel: label of the x axis
    :param ylabel: label of the y axis
    :param png_file: file name of the saved figure
    :param data: each argument is in the form of: {"legend" : {episode_1:[data points], episode_2:[data points], ...}}
    """
    for series in data:
        # only the first (legend, points) entry of each argument is used
        legend, points = list(series.items())[0]
        episodes = list(points.keys())
        values = list(points.values())
        # dots instead of lines
        plt.plot(episodes, values, marker='.', linestyle='none', label=legend)

    plt.legend()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    output_path = os.path.join(result_folder_index, png_file)
    plt.savefig(output_path)
    plt.close('all')


def check_one_hot_encoding_validity(X):
    """
    Check that every categorical feature of `X` is a valid one-hot encoding.

    For each group of one-hot columns, every row must hold exactly one 1 and zeros everywhere else.

    :param X: 2-D tensor indexed as X[row, column]
    :raises ValueError: if any group of one-hot columns violates the rule above
    """
    for indexes in dataset.cate_one_hot_index_dict.values():
        one_hot_block = X[:, indexes]
        zeros_per_row = (one_hot_block == 0).sum(-1)
        ones_per_row = (one_hot_block == 1).sum(-1)

        zeros_ok = torch.all(zeros_per_row == (len(indexes) - 1))
        ones_ok = torch.all(ones_per_row == 1)
        if zeros_ok and ones_ok:
            continue

        raise ValueError("Something is wrong with the one-hot encoding: {}".format(X[:, indexes]))


def start_reinforce_searching(s0):
    """
    Search for counterfactual explanations (CFEs) of `s0` by training the REINFORCE agent.

    Every episode starts from `s0` and applies up to `max_number_of_interventions` interventions; it stops
    early as soon as the prediction model outputs `target_class` (a valid CFE). After each episode the
    policy is updated with `finish_episode` and the episode's metrics are recorded. When `debug` is set,
    plots and text summaries are written to `result_folder_index` every `training_plot_interval` episodes,
    and each new CFE as well as the top-10 CFEs are plotted.

    :param s0: original input of shape (1, length_of_sequence, input_dimension), predicted as `invalid_class`
    :return: tuple of
             - generated_CFEs: {episode: CFE as numpy array}; with `unique_CFEs`, repeated CFEs are left out,
             - weighted proximity of all valid episodes, {episode: proximity}, sorted ascending by proximity,
             - best_CFE: the recorded CFE with the smallest weighted proximity (restricted to plausible
               ones if `plausible_CFEs`), or None if there is none.
    :raises ValueError: if the agent selects a non-actionable feature, or if the first state of an episode
                        differs from `s0`.
    """
    print("Start RL reinforce...")

    assert s0.shape == (1, dataset.length_of_sequence, dataset.input_dimension)
    assert utils.make_model_prediction(prediction_model, s0, dataset).round() == invalid_class
    assert max_episodes % training_plot_interval == 0  # so that the info at the end of training is saved

    def _sorted_by_value(records):
        """Return a copy of `records` ordered by ascending value."""
        return dict(sorted(records.items(), key=lambda pair: pair[1]))

    # Length: number of episodes
    mean_rewards_invalid_records = {}  # for intervened states that are invalid
    mean_rewards_valid_records = {}  # for intervened states that are valid (i.e. CFEs)
    proximity_weighted_invalid_records = {}
    proximity_weighted_valid_records = {}
    sparsity_invalid_records = {}
    sparsity_valid_records = {}
    validity_records = {}
    plausibility_valid_records = {}
    interventions_to_valid_records = {}
    interv_feature_name_records = {}
    # Length: number of episodes * number of interventions
    action_feature_records = []
    action_timestep_records = []
    action_value_records = []
    generated_CFEs = {}
    for i_episode in range(1, max_episodes + 1):
        if debug:
            print("i_episode: ", i_episode)
        state = torch.flatten(s0, start_dim=1)
        interv_feature_names = dataset.feature_names.copy()
        action_features = []
        action_timesteps = []
        action_values = []
        done = False
        for t in range(1, max_number_of_interventions + 1):
            # Action and state transition logic:
            # - if the index of a continuous feature is selected, then
            #   X[sequence, feature] = X[sequence, feature] + action_value
            # - if the index of a categorical feature is selected, use a categorical distribution to
            #   generate a value for its one-hot encoding

            state = torch.flatten(state, start_dim=1)
            intervene_on_categorical, action_feature, action_timestep, action_value = select_action(state)

            if action_feature in nonactionable_feature_indexes:
                raise ValueError("Feature {} is immutable.".format(dataset.feature_names[action_feature]))

            action_features.append(action_feature)
            action_timesteps.append(action_timestep)
            action_values.append(action_value)

            state_original_shape = torch.reshape(state, shape=s0.shape)

            if t == 1:
                if not torch.all(state_original_shape == s0):
                    # the message needs a placeholder, otherwise the comparison result is silently dropped
                    raise ValueError("What??? state_orginal_shape == s0: {}".format(state_original_shape == s0))

            state, reward, done = env_step(intervene_on_categorical, state_original_shape, action_feature,
                                           action_timestep, action_value)

            assert state.shape == (1, dataset.length_of_sequence, dataset.input_dimension)

            check_one_hot_encoding_validity(state[0])

            rl_model.rewards.append(reward)

            interv_feature_names[action_feature] = dataset.feature_names[action_feature] + " (Intervened)"

            if done:

                if debug:
                    print("i_episode: ", i_episode)
                    print("Valid CFE found!!!")
                    print("Found this CFE after {} interventions!".format(t))

                CFE = state.squeeze(0)

                if unique_CFEs and np.any([np.all(CFE.numpy() == generated_CFE)
                                           for generated_CFE in list(generated_CFEs.values())]):
                    if debug:
                        print("The same CFE was generated already. Skip plotting.")
                    break

                generated_CFEs[i_episode] = CFE.numpy()

                interv_feature_name_records[i_episode] = interv_feature_names

                if debug:
                    # Plot the valid CFE
                    plot_X_CFE_and_Diff(s0, CFE, interv_feature_names, t, 'generated_valid_CFE_{}'.format(i_episode))

                break

        # must be computed before `finish_episode`, which empties `rl_model.rewards`
        current_ep_mean_rewards = sum(rl_model.rewards) / len(rl_model.rewards)

        finish_episode()

        proximity_weighted = utils.compute_proximity(torch.reshape(s0, state[0].shape), state[0], dataset)
        sparsity = utils.compute_sparsity(torch.reshape(s0, state[0].shape), state[0], dataset)

        y_pred = utils.make_model_prediction(prediction_model, state, dataset).round().item()
        validity_records[i_episode] = y_pred

        if done:
            # Note: a valid episode whose CFE was skipped as a duplicate is still recorded here, so these
            # records can contain episodes that are missing from `generated_CFEs`.
            mean_rewards_valid_records[i_episode] = current_ep_mean_rewards
            proximity_weighted_valid_records[i_episode] = proximity_weighted
            sparsity_valid_records[i_episode] = sparsity

            plausibility_valid = utils.compute_plausibility(state[0], dataset, lof)
            plausibility_valid_records[i_episode] = plausibility_valid

            interventions_to_valid_records[i_episode] = t
        else:
            mean_rewards_invalid_records[i_episode] = current_ep_mean_rewards
            proximity_weighted_invalid_records[i_episode] = proximity_weighted
            sparsity_invalid_records[i_episode] = sparsity

        action_feature_records.extend(action_features)
        action_timestep_records.extend(action_timesteps)
        action_value_records.extend(action_values)

        if debug and (i_episode % training_plot_interval == 0):  # save training process

            # plots
            plot_training_info("Episodes", "Mean Reward", "0training_mean_rewards.png",
                               {"Valid": mean_rewards_valid_records},
                               {"Invalid": mean_rewards_invalid_records})
            plot_training_info("Episodes", "Proximity (Weighted)", "0training_Proximity_weighted.png",
                               {"Valid": proximity_weighted_valid_records},
                               {"Invalid": proximity_weighted_invalid_records})
            plot_training_info("Episodes", "Sparsity", "0training_Sparsity.png",
                               {"Valid": sparsity_valid_records},
                               {"Invalid": sparsity_invalid_records})
            plot_training_info("Episodes", "Validity", "0training_Validity.png", {"All": validity_records})
            plot_training_info("Episodes", "Number of interventions to generate a valid CFE",
                               "0training_Number_of_Interventions_to_Valid.png",
                               {"Valid": interventions_to_valid_records})
            plot_training_info("Episodes", "Plausibility of Valid CFEs (inliers: {}, outliers: {})"
                               .format(utils.Plausibility.plausible, utils.Plausibility.implausible),
                               "0training_Plausibility_Valid.png", {"Valid": plausibility_valid_records})

            # the fourth subplot is intentionally left empty
            fig, ((ax1, ax2), (ax3, _)) = plt.subplots(2, 2, figsize=(WHOLE_FIG_W * 2, WHOLE_FIG_H))
            ax1.plot(action_timestep_records, marker='.', linestyle='none', label="intervened_timesteps")
            ax1.yaxis.set_major_locator(MaxNLocator(integer=True))  # force y ticks to be integers
            ax1.set_xlabel("Episodes * Number of Interventions")
            ax1.set_ylabel("Values")
            ax1.set_title("intervened_timesteps")
            ax1.legend()
            ax2.plot(action_feature_records, marker='.', linestyle='none', label="intervened_features")
            ax2.yaxis.set_major_locator(MaxNLocator(integer=True))  # force y ticks to be integers
            ax2.set_xlabel("Episodes * Number of Interventions")
            ax2.set_ylabel("Values")
            ax2.set_title("intervened_features")
            ax2.legend()
            ax3.plot(action_value_records, marker='.', linestyle='none', label="intervened_values")
            ax3.set_xlabel("Episodes * Number of Interventions")
            ax3.set_ylabel("Values")
            ax3.set_title("intervened_values")
            ax3.legend()
            fig.savefig(os.path.join(result_folder_index, "0training_interventions.png"))
            plt.close('all')

            # save text files
            utils.overwrite_file(os.path.join(result_folder_index, "0training_Number_of_Interventions_to_Valid.txt"),
                                 "Number of interventions to generate a valid CFE: \n"
                                 + str(_sorted_by_value(interventions_to_valid_records)).replace(", ", "\n"))

            utils.overwrite_file(os.path.join(result_folder_index, "0training_Proximity_Weighted.txt"),
                                 "Proximity (weighted) of valid CFEs: \n"
                                 + str(_sorted_by_value(proximity_weighted_valid_records)).replace(", ", "\n"))

            utils.overwrite_file(os.path.join(result_folder_index, "0training_Sparsity.txt"),
                                 "Sparsity of valid CFEs: \n"
                                 + str(_sorted_by_value(sparsity_valid_records)).replace(", ", "\n"))

    # Sorted once here instead of after every episode; it is also defined when no episode was run.
    sorted_proximity_weighted_valid_records = _sorted_by_value(proximity_weighted_valid_records)

    best_CFE = None

    if generated_CFEs:  # at least one valid CFE was recorded
        # episodes ordered from the smallest to the largest weighted proximity
        sorted_episodes_by_proximity = list(sorted_proximity_weighted_valid_records.keys())

        # The best CFE is the first recorded one in that order (that is also plausible, if required).
        for episode in sorted_episodes_by_proximity:
            if plausible_CFEs and plausibility_valid_records[episode] != utils.Plausibility.plausible:
                continue

            candidate = generated_CFEs.get(episode)
            if candidate is None:
                # valid episode whose CFE was skipped as a duplicate; same handling as in the loop below
                continue

            best_CFE = candidate
            break

        top_N = 10  # plot the top-N CFEs
        i = 0
        while i < top_N:

            if i >= len(sorted_episodes_by_proximity):
                print("all CFEs are plotted.")
                break

            episode = sorted_episodes_by_proximity[i]

            if plausible_CFEs:
                plausible = plausibility_valid_records[episode]
                if plausible == utils.Plausibility.implausible:  # do not plot if not plausible when plausible is required
                    top_N += 1  # so that the total number of plotted CFEs is still `top_N`
                    i += 1
                    continue

            CFE_np = generated_CFEs.get(episode)
            # this can be None if the corresponding CFE is not unique and has been seen before. In this case,
            # `sorted_episodes_by_proximity` has it but `generated_CFEs` does not have it.
            if CFE_np is None:
                top_N += 1  # so that the total number of plotted CFEs is still `top_N`
                i += 1
                continue

            if debug:
                plot_X_CFE_and_Diff(s0, torch.tensor(CFE_np), interv_feature_name_records[episode],
                                    interventions_to_valid_records[episode], 'Top_{}_CFE_{}'.format(i, episode))

            i += 1

    return generated_CFEs, sorted_proximity_weighted_valid_records, best_CFE


def plot_X_CFE_and_Diff(x, CFE, interv_feature_names, num_of_interventions, file_name):
    """
    Plot the original X, the generated CFE, their difference.

    Print metric on the plot.

    Save the generated CFE as a csv file.

    Files written to `result_folder_index`: `<file_name>.csv` (the CFE in plot format), `<file_name>.txt`
    (metrics, appended) and `<file_name>.png` (the figure).

    :param x: original input; reshaped to the shape of `CFE` for the metrics
    :param CFE: the counterfactual explanation in one-hot format
    :param interv_feature_names: feature names used as labels of the CFE plot. The list is not modified.
    :param num_of_interventions: number of interventions that were needed to find `CFE`
    :param file_name: base name (without extension) of the output files
    :raises ValueError: if an immutable feature differs between the original input and the CFE
    """

    # This column name is too long in the legend. The shortened labels are put into a new list, so that
    # the caller's list (which is also kept in the per-episode records) is left untouched.
    shortened_names = {
        "People_using_at_least_basic_drinking_water_services":
            "People_drinking_water_services",
        "People_using_at_least_basic_drinking_water_services (Intervened)":
            "People_drinking_water_services (Intervened)",
    }
    interv_feature_names = [shortened_names.get(name, name) for name in interv_feature_names]

    weighted_proximity = utils.compute_proximity(torch.reshape(x, CFE.shape), CFE, dataset)
    unweighted_proximity = utils.compute_proximity(torch.reshape(x, CFE.shape), CFE, dataset, True)
    sparsity = utils.compute_sparsity(torch.reshape(x, CFE.shape), CFE, dataset)
    plausibility_valid = utils.compute_plausibility(CFE, dataset, lof)
    CFE_pred = utils.make_model_prediction(prediction_model, CFE.unsqueeze(0), dataset).item()

    # the subplots have the same scale in X
    fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(WHOLE_FIG_W * 2, WHOLE_FIG_H / 2), sharex=True)
    ax5.set_visible(False)  # use its space for legend
    fig.tight_layout(pad=6.0)  # add space between subplots

    utils.plot_X(ax1, original_x_plot_format, None, "Original User Input", "Time Step", "Value", dataset, FIG_FONT_SIZE)

    CFE_plot_format = utils.undo_one_hot(CFE, dataset)
    utils.plot_X(ax2, CFE_plot_format, interv_feature_names, "Generated CFE", "Time Step", "Value", dataset,
                 FIG_FONT_SIZE)
    np.savetxt(os.path.join(result_folder_index, file_name + ".csv"), CFE_plot_format.numpy(), delimiter=',')

    difference = CFE_plot_format - original_x_plot_format
    changed_in_x = original_x_plot_format.clone()
    changed_in_x[difference == 0] = float('nan')  # do not plot features that are not changed
    utils.plot_X(ax3, changed_in_x, None, "Features Changed (Original)", "Time Step", "Value", dataset, FIG_FONT_SIZE)

    changed_in_cfe = CFE_plot_format.clone()
    changed_in_cfe[difference == 0] = float('nan')  # do not plot features that are not changed
    utils.plot_X(ax4, changed_in_cfe, None, "Features Changed (CFE)", "Time Step", "Value", dataset, FIG_FONT_SIZE)

    # Unchanged entries become NaN; the immutability check at the end relies on this marking.
    difference[difference == 0] = float('nan')

    cfe_result_str = "Found this CFE after {} interventions! \nPrediction={}" \
                     "\nProximity (weighted) = {} \nProximity (unweighted) = {} \nSparsity = {}" \
                     "\nPlausibility (inliers: {}, outliers: {}) = {} " \
        .format(num_of_interventions, CFE_pred, weighted_proximity, unweighted_proximity, sparsity,
                utils.Plausibility.plausible, utils.Plausibility.implausible, plausibility_valid)
    utils.append_to_file(os.path.join(result_folder_index, file_name + ".txt"), cfe_result_str)

    # make subplots the same scale
    value_axes = (ax1, ax2, ax3, ax4)
    y_axis_min = min(ax.get_ylim()[0] for ax in value_axes)
    y_axis_max = max(ax.get_ylim()[1] for ax in value_axes)
    for ax in value_axes:
        ax.set_ylim([y_axis_min, y_axis_max])

    fig.legend(loc='center left', bbox_to_anchor=(0.79, 0.5), fontsize=FIG_FONT_SIZE / 1.1)

    fig.savefig(os.path.join(result_folder_index, file_name + ".png"))
    plt.close('all')

    verify_immutable_features_not_changed(immutable_feature_indexes, difference)


def make_feature_category(immutable_features, nonactionable_features, scm_functions, feature_names=None):
    """
    Make categories for features: "immutable", "non-actionable", "regular"

    immutable: cannot be intervened nor changed by SCMs
    non-actionable: cannot be intervened, but may or may not be changed by SCMs
    regular: anything else

    :param immutable_features: list of feature names that are immutable.
    :param nonactionable_features: list of feature names that are non-actionable.
        This list is NOT modified; immutable features are merged into a fresh copy.
    :param scm_functions: dict keyed by feature name. Entries keyed by an immutable
        feature are removed IN PLACE, and the same dict object is returned.
    :param feature_names: ordered list of feature names used to resolve indexes.
        Defaults to the module-level ``dataset.feature_names``.
    :return: tuple ``(immutable_feature_indexes, nonactionable_feature_indexes, scm_functions)``.
        ``nonactionable_feature_indexes`` follows first-seen order of
        ``nonactionable_features`` followed by ``immutable_features``, without duplicates.
    :raises ValueError: if a feature name is not present in ``feature_names``.
    """
    if feature_names is None:
        feature_names = dataset.feature_names

    # immutable features are also non-actionable.
    # dict.fromkeys de-duplicates while keeping first-seen order, so the resulting
    # index list is reproducible across runs (set() ordering of strings is not).
    all_nonactionable_features = list(dict.fromkeys([*nonactionable_features, *immutable_features]))

    # immutable features cannot be changed by SCM functions
    for immutable_feature in immutable_features:
        scm_functions.pop(immutable_feature, None)  # remove the SCM function that changes the immutable feature

    nonactionable_feature_indexes = [feature_names.index(nonact_f) for nonact_f in all_nonactionable_features]

    immutable_feature_indexes = [feature_names.index(immu_f) for immu_f in immutable_features]

    return immutable_feature_indexes, nonactionable_feature_indexes, scm_functions


def verify_immutable_features_not_changed(immutable_feature_indexes, X_plot_difference):
    """
    Sanity check: every immutable feature column must be entirely NaN in the difference tensor.

    :param immutable_feature_indexes: column indexes of the immutable features.
    :param X_plot_difference: 2-D tensor indexed as [row, feature]; a column counts as
        "unchanged" only when all of its entries are NaN.
    :raises ValueError: for the first immutable feature whose column holds a non-NaN entry.
    """
    for feature_idx in immutable_feature_indexes:
        feature_column = X_plot_difference[:, feature_idx]

        # Guard clause: an all-NaN column means nothing was changed for this feature.
        if torch.isnan(feature_column).all():
            continue

        feature_name = dataset.feature_names[feature_idx]
        raise ValueError("Immutable feature {} is changed???".format(feature_name))


if __name__ == "__main__":

    # Echo the raw command line first so every log starts with the exact invocation.
    print(sys.argv)

    # Start of the whole run; passed to utils.log_result_summary at the end of the script.
    total_start_time = time.time()

    # Allowed values for the two enumerated string options.
    dataset_name_choices = [
        utils.Dataset_Names.life_expectancy,
        utils.Dataset_Names.natops,
        utils.Dataset_Names.heartbeat,
        utils.Dataset_Names.racket_sports,
        utils.Dataset_Names.basic_motions,
        utils.Dataset_Names.ering,
        utils.Dataset_Names.japanese_vowels,
        utils.Dataset_Names.libras,
        utils.Dataset_Names.PEMS_SF,
    ]
    prediction_model_type_choices = [
        utils.Prediction_Model_Types.LSTM,
        utils.Prediction_Model_Types.KNN,
        utils.Prediction_Model_Types.RandomForest,
        utils.Prediction_Model_Types.Rule,
    ]

    # All options are optional; on/off switches are passed as integers (0 / 1).
    parser = argparse.ArgumentParser()

    # --- run control / output location ---
    parser.add_argument('--debug', type=int, default=0)
    parser.add_argument('--skip_finished', type=int, default=0)
    parser.add_argument('--random_seed', type=int, default=1)
    parser.add_argument('--result_folder', type=str, default="temp_results_RL")
    parser.add_argument('--result_folder_suffix', type=str, default="")

    # --- dataset and prediction model ---
    parser.add_argument('--dataset_name', type=str, default=utils.Dataset_Names.life_expectancy,
                        choices=dataset_name_choices)
    parser.add_argument('--prediction_model_type', type=str, default=utils.Prediction_Model_Types.LSTM,
                        choices=prediction_model_type_choices)
    parser.add_argument('--prediction_model_version', type=str, default=None)

    # --- search / reward settings ---
    parser.add_argument('--lambda_proximity', type=float, default=0.001)
    parser.add_argument('--lambda_entropy', type=float, default=0.0)
    parser.add_argument('--max_number_of_interventions', type=int, default=100)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--intervention_type', type=str, default="drifting", choices=["drifting", "point"])
    parser.add_argument('--feature_extreme_values', type=int, default=1)
    parser.add_argument('--feature_proximity_weights', nargs="+", type=float, default=None,
                        help="[c1, c2, ...] for each feature. A small value means the corresponding feature weights "
                             "less in proximity computation: the user prefers to change the feature. Default=None: "
                             "meaning [1, 1, ...]")

    # --- optimisation and bookkeeping ---
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--unique_CFEs', type=int, default=1,
                        help="saving all generated CFEs vs saving only unique CFEs.")
    parser.add_argument('--plausible_CFEs', type=int, default=0,
                        help="saving all generated CFEs vs saving only plausible CFEs.")
    parser.add_argument('--max_episodes', type=int, default=10000)
    parser.add_argument('--training_plot_interval', type=int, default=1000)

    parser.add_argument('--rl_continue', type=int, default=0, choices=[0, 1])

    args = parser.parse_args()

    # Copy every option into a module-level name.
    # NOTE(review): functions in this file read module globals directly (e.g. `dataset`
    # in make_feature_category), and presumably these settings too — keep the names stable.
    debug = args.debug
    skip_finished = args.skip_finished
    random_seed = args.random_seed
    dataset_name = args.dataset_name
    prediction_model_type = args.prediction_model_type
    prediction_model_version = args.prediction_model_version
    lambda_proximity = args.lambda_proximity

    # Entropy placeholders start out unset.
    feature_entropy = None
    timestep_entropy = None
    lambda_entropy = args.lambda_entropy

    max_number_of_interventions = args.max_number_of_interventions
    gamma = args.gamma
    intervention_type = args.intervention_type
    feature_extreme_values = args.feature_extreme_values
    feature_proximity_weights = args.feature_proximity_weights
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay

    # Results root = base folder name with the optional suffix appended directly.
    result_folder_root = args.result_folder + args.result_folder_suffix

    unique_CFEs = args.unique_CFEs
    plausible_CFEs = args.plausible_CFEs
    max_episodes = args.max_episodes
    training_plot_interval = args.training_plot_interval
    rl_continue = args.rl_continue

    # Echo the effective configuration, one "<name>:  <value>" line per setting.
    # prediction_model_version is shown as given on the command line (before the default is resolved).
    for setting_label, setting_value in (
            ("debug", debug),
            ("skip_finished", skip_finished),
            ("random_seed", random_seed),
            ("dataset_name", dataset_name),
            ("prediction_model_type", prediction_model_type),
            ("prediction_model_version", prediction_model_version),
            ("lambda_proximity", lambda_proximity),
            ("lambda_entropy", lambda_entropy),
            ("max_number_of_interventions", max_number_of_interventions),
            ("gamma", gamma),
            ("intervention_type", intervention_type),
            ("feature_extreme_values", feature_extreme_values),
            ("feature_proximity_weights", feature_proximity_weights),
            ("learning_rate", learning_rate),
            ("weight_decay", weight_decay),
            ("result_folder_root", result_folder_root),
            ("unique_CFEs", unique_CFEs),
            ("plausible_CFEs", plausible_CFEs),
            ("max_episodes", max_episodes),
            ("training_plot_interval", training_plot_interval),
            ("rl_continue", rl_continue),
    ):
        print(setting_label + ": ", setting_value)

    # Files that live directly under the results root.
    config_file = os.path.join(result_folder_root, "configurations.txt")
    result_summary_file = os.path.join(result_folder_root, "0result_summary.txt")

    os.makedirs(result_folder_root, exist_ok=True)

    utils.log_config(config_file, args, result_summary_file)

    # Seed before any dataset or model work is done.
    utils.set_random_seed(random_seed)

    prediction_model_version = utils.get_default_prediction_model_version(prediction_model_version, dataset_name)

    ################################## Loading dataset ##################################
    data_path_root = os.getcwd()

    # Datasets that share one loading recipe: construct, load the "test" split, and
    # use the loaded tensor as-is for the non-one-hot view (a clone of it).
    standard_dataset_classes = {
        utils.Dataset_Names.natops: Dataset_NATOPS,
        utils.Dataset_Names.heartbeat: Dataset_Heartbeat,
        utils.Dataset_Names.racket_sports: Dataset_RacketSports,
        utils.Dataset_Names.basic_motions: Dataset_BasicMotions,
        utils.Dataset_Names.ering: Dataset_ERing,
        utils.Dataset_Names.japanese_vowels: Dataset_JapaneseVowels,
        utils.Dataset_Names.libras: Dataset_Libras,
        utils.Dataset_Names.PEMS_SF: Dataset_PEMS_SF,
    }

    # Settings that are the same for every supported dataset:
    # class 0 is the prediction to explain, class 1 is the class a CFE should reach,
    # and no SCM functions are registered.
    invalid_class = 0
    target_class = 1
    scm_functions = {}

    if dataset_name == utils.Dataset_Names.life_expectancy:

        dataset = Dataset_Life_Expectancy(data_path_root, None, feature_proximity_weights)

        X_one_hot, y, X_max, X_min = dataset.data_process_LE_one_hot()

        X_not_one_hot = utils.undo_all_one_hot(X_one_hot, dataset)

        assert dataset.cate_index_start == dataset.num_of_features - len(dataset.categorical_features)

        # label: Life Expectancy in 2015

        # assume these features are non-actionable (may or may not be immutable)
        nonactionable_features = ['Population', 'Obesity_among_adults', 'Beer_consumption_per_capita']

        # assume these features are immutable (and therefore, non-actionable)
        immutable_features = ['Continent']

    elif dataset_name in standard_dataset_classes:
        dataset_class = standard_dataset_classes[dataset_name]
        dataset = dataset_class(data_path_root, None, feature_proximity_weights)
        X_one_hot, y, X_max, X_min = dataset.load_dataset("test")
        X_not_one_hot = X_one_hot.clone()

        # assume these features are non-actionable (may or may not be immutable)
        nonactionable_features = []

        # assume these features are immutable (and therefore, non-actionable)
        immutable_features = []

    else:
        raise ValueError("what??? dataset_name={}".format(dataset_name))

    # Turn the feature-name lists into index lists; the name lists are dropped afterwards
    # so only the index form stays in scope.
    immutable_feature_indexes, nonactionable_feature_indexes, scm_functions = \
        make_feature_category(immutable_features, nonactionable_features, scm_functions)
    del nonactionable_features
    del immutable_features

    ###########################################################################

    lof = utils.make_LOF(X_one_hot, dataset, result_summary_file)

    print("feature_names: ", dataset.feature_names)

    # Model types that are fitted here with scikit-learn instead of being loaded.
    sklearn_model_types = (utils.Prediction_Model_Types.RandomForest, utils.Prediction_Model_Types.KNN)

    if prediction_model_type == utils.Prediction_Model_Types.LSTM:

        # Load the saved LSTM model
        prediction_model = LSTM_Model(dataset.input_dimension, dataset.output_dimension)
        saved_state = torch.load(prediction_model_version)
        prediction_model.load_state_dict(saved_state)
        prediction_model.eval()

        y_pred = prediction_model(X_one_hot)  # tensor (N, 1)

    elif prediction_model_type == utils.Prediction_Model_Types.Rule:
        prediction_model = Rule_Model(dataset, prediction_model_version, dataset_name)

        y_pred = prediction_model.predict(X_not_one_hot).reshape(-1, 1)

    elif prediction_model_type in sklearn_model_types:

        # scikit-learn wants 2-D inputs: flatten each sequence into one row.
        flattened_width = dataset.length_of_sequence * dataset.num_of_features
        X_2d = X_not_one_hot.reshape((-1, flattened_width)).numpy()
        Y_1d = y.reshape((y.shape[0],)).numpy()

        if prediction_model_type == utils.Prediction_Model_Types.RandomForest:
            prediction_model = RandomForestClassifier(n_estimators=100, class_weight='balanced')
            prediction_model.prediction_model_type = utils.Prediction_Model_Types.RandomForest
        else:
            # KNN is the only other member of sklearn_model_types.
            neighbor_count = int(np.sqrt(len(X_2d)))
            prediction_model = KNeighborsClassifier(n_neighbors=neighbor_count)
            prediction_model.prediction_model_type = utils.Prediction_Model_Types.KNN

        prediction_model.fit(X_2d, Y_1d)

        y_pred_probs = prediction_model.predict_proba(X_2d)  # assume predicting on training dataset is okay
        assert y_pred_probs.shape[-1] == 2
        y_pred = torch.tensor(y_pred_probs[:, 1]).reshape(-1, 1)  # the probability of being 1

    else:
        raise ValueError("what??? prediction_model_type={}".format(prediction_model_type))

    # Fraction of instances whose rounded prediction is 1.
    portion_of_ones = y_pred.round().sum() / len(y_pred)
    print("Portion of prediction of 1s: ", portion_of_ones.item())

    # With rl_continue == 1 a single policy network and optimizer are created once and
    # reused for every instance; otherwise they are re-created per instance inside the loop.
    if rl_continue == 1:
        rl_model = PolicyNetwork(dataset)
        optimizer = optim.Adam(rl_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # find an instance where the prediction is `invalid_class`
    # Running totals over all such instances; the "*_valid*" variants only accumulate
    # CFEs whose prediction reaches `target_class`.
    count_invalid_X = 0.0
    count_CFE_found = 0.0
    count_valid_CFE_found = 0.0
    count_plausible = 0.0
    count_plausible_valid = 0.0
    proximity_total = 0.0
    sparsity_total = 0.0
    proximity_valid_total = 0.0
    sparsity_valid_total = 0.0
    total_elapsed_seconds = None
    for y_index, y_pred_current in enumerate(y_pred):
        if y_pred_current.round() == invalid_class:

            current_start_time = time.time()

            count_invalid_X += 1
            print("\ninvalid X index: ", y_index)
            original_x = X_one_hot[y_index, :, :]  # shape: (sequence length, one-hot encoding dimension)

            # Each instance gets its own result folder and result file.
            result_folder_index = os.path.join(result_folder_root, "X_Index_" + str(y_index))
            os.makedirs(result_folder_index, exist_ok=True)
            result_current_file = os.path.join(result_folder_index, "0result_current_X.txt")

            # An existing result file marks the instance as already processed.
            if skip_finished:
                if Path(result_current_file).is_file():
                    print(f"Result file exists: {result_current_file}. Skip.")
                    continue
                else:
                    print(f"Running for: {result_current_file}...")

            original_x_plot_format = utils.undo_one_hot(original_x, dataset)

            if debug:
                # Plot the original x
                fig, ax = plt.subplots(figsize=(WHOLE_FIG_W, WHOLE_FIG_H / 2))
                utils.plot_X(ax, original_x_plot_format, dataset.feature_names, "Original User Input", "Time Step",
                             "Value", dataset, FIG_FONT_SIZE)
                fig.text(x=0.5, y=0.5, s="Prediction={}".format(y_pred_current.item()))
                fig.savefig(os.path.join(result_folder_index, 'original_x.png'))
                plt.close('all')

            if rl_continue == 0:
                # Fresh policy network and optimizer for every instance.
                rl_model = PolicyNetwork(dataset)
                optimizer = optim.Adam(rl_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
            else:
                rl_model.check_for_continue()

            try:
                _, _, best_CFE = start_reinforce_searching(original_x.unsqueeze(0))
            except ValueError as e:
                message = str(e)

                if "nan" not in message:
                    # Unrelated failure: re-raise the original exception unchanged
                    # (same type, message and traceback) instead of wrapping it.
                    raise

                # A ValueError whose message mentions "nan" is treated as "no CFE found".
                print(message)
                best_CFE = None

            current_end_time = time.time()
            current_elapsed = str(timedelta(seconds=current_end_time - current_start_time))
            print("Time Elapsed: {}".format(current_elapsed))
            utils.append_to_file(result_current_file, "\nTime Elapsed: {}\n".format(current_elapsed))

            if best_CFE is None:
                # if no valid CFEs are found, do not update the counts
                utils.append_to_file(result_current_file, "\nNo valid CFEs are found.\n")
                continue

            count_CFE_found += 1

            cfe_tensor = torch.tensor(best_CFE)

            cfe_prediction = utils.make_model_prediction(prediction_model, cfe_tensor.unsqueeze(0), dataset)

            ###################################### Compute Metrics ######################################

            # Validity: the rounded CFE prediction must be exactly one of the two classes.
            if cfe_prediction.round() == target_class:
                valid_cfe_found = True
            elif cfe_prediction.round() == invalid_class:
                valid_cfe_found = False
            else:
                raise ValueError("what??? cfe_prediction={}".format(cfe_prediction))

            print("valid_cfe_found: ", valid_cfe_found)
            if valid_cfe_found:
                count_valid_CFE_found += 1

            # Proximity is reported with force_equal_weights=True, regardless of
            # feature_proximity_weights.
            proximity = utils.compute_proximity(original_x, cfe_tensor, dataset, force_equal_weights=True)
            proximity_total += proximity

            sparsity = utils.compute_sparsity(original_x, cfe_tensor, dataset)
            sparsity_total += sparsity

            if valid_cfe_found:
                proximity_valid_total += proximity
                sparsity_valid_total += sparsity

            plausibility = utils.compute_plausibility(cfe_tensor, dataset, lof)
            if plausibility == utils.Plausibility.plausible:
                count_plausible += 1
                if valid_cfe_found:
                    count_plausible_valid += 1
            elif plausibility == utils.Plausibility.implausible:
                pass
            else:
                raise ValueError("what??? plausibility={}".format(plausibility))

            utils.append_to_file(result_current_file,
                                 "\nValidity: {}\n\nProximity: {}\n\nSparsity: {}\n\nPlausibility: {}\n"
                                 .format(valid_cfe_found, proximity, sparsity, plausibility))

    if skip_finished:
        print("All experiments finished.")

        # In skip-finished mode the in-memory totals are replaced by the values returned
        # from utils.read_all_result_files (presumably aggregated from the per-instance
        # result files — confirm in utils).
        parsed_results = utils.read_all_result_files(result_folder_root)

        (count_invalid_X,
         count_CFE_found,
         count_valid_CFE_found,
         count_plausible,
         count_plausible_valid,
         proximity_total,
         sparsity_total,
         proximity_valid_total,
         sparsity_valid_total) = [
            parsed_results[result_key]
            for result_key in (
                "count_invalid_X",
                "count_CFE_found",
                "count_valid_CFE_found",
                "count_plausible",
                "count_plausible_valid",
                "proximity_total",
                "sparsity_total",
                "proximity_valid_total",
                "sparsity_valid_total",
            )
        ]

        # The start time of this process is dropped; the parsed elapsed seconds are reported instead.
        total_start_time = None
        total_elapsed_seconds = parsed_results["total_elapsed_seconds"]

    utils.log_result_summary(
        result_summary_file,
        total_start_time,
        count_invalid_X,
        count_CFE_found,
        count_valid_CFE_found,
        count_plausible,
        count_plausible_valid,
        proximity_total,
        sparsity_total,
        proximity_valid_total,
        sparsity_valid_total,
        total_elapsed_seconds,
    )
