# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import os
import torch
import utils
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder


class Dataset_Life_Expectancy:
    """Life Expectancy dataset: one multivariate yearly time series per country plus a binary label.

    Each sample is a country observed over ``length_of_sequence`` consecutive yearly rows
    with ``num_of_features`` features. The label is 1 when the country's 2015 life
    expectancy is at least 75, otherwise 0.
    """

    def __init__(self, data_path_root, result_folder, feature_proximity_weights=None):
        """Store the dataset configuration; no data is read here.

        Args:
            data_path_root: Root directory under which the Life Expectancy CSV is located.
            result_folder: Directory where ``data_process_LE_one_hot`` writes the
                standardization means/stds, or None to skip writing them.
            feature_proximity_weights: Optional sequence with one weight per feature
                (``num_of_features`` entries). Defaults to all ones. Stored as a
                (1, num_of_features) tensor.

        Raises:
            ValueError: If ``feature_proximity_weights`` does not reshape to (1, num_of_features).
        """
        self.data_path_root = data_path_root
        self.length_of_sequence = 16  # years from 2000 to 2015
        self.num_of_features = 14  # number of time series
        self.categorical_features = ['Continent', 'Least_Developed']
        self.cate_num_classes_dict = {'Continent': 6, 'Least_Developed': 2}  # categorical feature name: # of classes
        self.result_folder = result_folder
        self.input_dimension = 20  # 12 continuous features + 2 categorical features with 6 classes and 2 classes
        self.output_dimension = 1
        # The attributes below are filled in by data_process_LE_one_hot().
        self.feature_names = None
        self.cate_index_start = None
        self.cate_one_hot_index_dict = None
        self.X_means = None
        self.X_stds = None

        if feature_proximity_weights is None:
            self.feature_proximity_weights = torch.ones(1, self.num_of_features)
        else:
            self.feature_proximity_weights = torch.tensor(feature_proximity_weights).reshape(1, -1)

            if self.feature_proximity_weights.shape != (1, self.num_of_features):
                raise ValueError("self.feature_proximity_weights.shape {} != (1, self.num_of_features {})."
                                 .format(self.feature_proximity_weights.shape, self.num_of_features))

    def _check_row_layout(self, data, label_year=2015):
        """Raise ValueError unless the rows of `data` form one contiguous block per country.

        Each block must have exactly ``length_of_sequence`` rows, a single ``Country`` value,
        and exactly one row whose ``Year`` equals ``label_year``. The reshape to
        (countries, years, features) and the pairing of samples with label rows both depend
        on this layout.
        """
        seq_len = self.length_of_sequence
        n_rows = len(data)
        if n_rows == 0 or n_rows % seq_len != 0:
            raise ValueError("expected a positive multiple of {} rows (one block per country), got {}."
                             .format(seq_len, n_rows))

        countries = data['Country'].to_numpy().reshape(-1, seq_len)
        if not np.all(countries == countries[:, :1]):
            raise ValueError("rows are not grouped into contiguous blocks of {} rows per country."
                             .format(seq_len))

        label_rows_per_block = (data['Year'].to_numpy().reshape(-1, seq_len) == label_year).sum(axis=1)
        if not np.all(label_rows_per_block == 1):
            raise ValueError("every country block must contain exactly one row with Year == {}."
                             .format(label_year))

    def data_process_LE_one_hot(self):
        """Load, label-encode, standardize and one-hot encode the Life Expectancy data.

        Reads ``Life_Expectancy_00_15.csv`` (';'-separated) under ``data_path_root`` and
        replaces spaces in column names with underscores. It then:

        - label-encodes the categorical features and moves them after the continuous ones;
        - binarizes the 2015 life expectancy into the label;
        - z-score standardizes every feature column (categoricals included);
        - reshapes the data to (countries, years, features);
        - passes the result through ``utils.convert_to_one_hot``.

        Side effects: sets ``self.cate_index_start``, ``self.X_means``, ``self.X_stds``,
        ``self.feature_names`` and ``self.cate_one_hot_index_dict``. When ``result_folder``
        is not None, writes the means/stds there as CSV files.

        Returns:
            X_final: Output of ``utils.convert_to_one_hot``. Presumably of shape
                (countries, length_of_sequence, input_dimension); confirm in ``utils``.
            Y_tensor: (countries, 1) float32 tensor of 0/1 labels, in file row order.
            X_max, X_min: Per-feature max/min (``num_of_features`` entries) over the
                standardized, pre-one-hot data.

        Raises:
            ValueError: If the rows are not laid out as one block of ``length_of_sequence``
                rows per country with one 2015 row each, or if the feature count differs
                from ``num_of_features``.
            AssertionError: If an internal consistency check fails.
        """
        data_path = os.path.join(self.data_path_root,
                                 "data/multivariate_time_series/Life_Expectancy/Life_Expectancy_00_15.csv")

        seq_len = self.length_of_sequence
        label_year = 2015  # only Life Expectancy in this year is used as the label
        target_life_expectancy = 75  # label=1 if life_expectancy >= target_life_expectancy, otherwise label=0

        # Read data
        data = pd.read_csv(data_path, header=0, delimiter=";")
        print("data.columns: ", data.columns.tolist())

        # for SCM, the variable name must be a valid python variable name
        data.columns = [name.replace(" ", "_") for name in data.columns.tolist()]
        print("formatted data.columns: ", data.columns.tolist())

        # Validate the row order up front; otherwise samples and labels could be silently misaligned.
        self._check_row_layout(data, label_year)

        Y_df = data[['Country', 'Year', 'Life_Expectancy']]
        X_df = data.drop(columns=['Country', 'Year', 'Life_Expectancy'])
        if X_df.shape[1] != self.num_of_features:
            raise ValueError("expected {} feature columns, found {}: {}"
                             .format(self.num_of_features, X_df.shape[1], X_df.columns.tolist()))

        print(X_df.head())
        print(Y_df.head())

        # .copy(): the labels are overwritten in place below, and writing into a filtered
        # frame triggers pandas' SettingWithCopyWarning.
        Y_df_labels = Y_df.loc[Y_df['Year'] == label_year].copy()

        # convert feature text values to numeric values
        for cate_feat in self.categorical_features:
            encoder = LabelEncoder()
            X_df[cate_feat] = encoder.fit_transform(X_df[cate_feat])
            categorical_mapping = dict(zip(encoder.classes_, encoder.transform(encoder.classes_)))
            print("mapping for {}: {}".format(cate_feat, categorical_mapping))

        # move categorical columns to the end
        for cate_feat in self.categorical_features:
            cate_feat_column_df = X_df.pop(cate_feat)
            X_df.insert(len(X_df.columns), cate_feat, cate_feat_column_df)

        # continuous features first, then categorical features. This is the index where categorical features start
        cate_index_start = min([X_df.columns.tolist().index(cate_feat) for cate_feat in self.categorical_features])
        self.cate_index_start = cate_index_start

        # convert label numerical values to binary values (NaN matches neither mask and is left as NaN)
        Y_df_labels.loc[Y_df_labels['Life_Expectancy'] < target_life_expectancy, 'Life_Expectancy'] = 0
        Y_df_labels.loc[Y_df_labels['Life_Expectancy'] >= target_life_expectancy, 'Life_Expectancy'] = 1

        print("Counter number of binary label values: \n", Y_df_labels['Life_Expectancy'].value_counts())

        # dataset standardization
        X_means = X_df.mean(axis=0)
        X_stds = X_df.std(axis=0)
        print("X_means: \n", X_means)
        print("X_stds: \n", X_stds)
        self.X_means = X_means.to_numpy()
        self.X_stds = X_stds.to_numpy()
        if self.result_folder is not None:
            np.savetxt(os.path.join(self.result_folder, "Life_Expectancy_means.csv"), self.X_means, delimiter=',')
            np.savetxt(os.path.join(self.result_folder, "Life_Expectancy_stds.csv"), self.X_stds, delimiter=',')

        # NOTE(review): a constant column has std 0 and would become NaN here.
        X_df = (X_df - X_means) / X_stds

        print(X_df.shape)
        print(Y_df_labels.shape)

        feature_names = X_df.columns.tolist()
        self.feature_names = feature_names

        # convert to numpy
        X_array = X_df.to_numpy()
        Y_array = Y_df_labels.to_numpy()

        print(X_array.shape)
        print(Y_array.shape)

        # (countries i.e. number of samples, years i.e. sequence length, features i.e number of time series),
        # e.g. (119, 16, 14); row grouping was validated by _check_row_layout above.
        X_reshaped = X_array.reshape((-1, seq_len, self.num_of_features))
        print(X_reshaped.shape)

        assert np.all(X_array[:seq_len, :] == X_reshaped[0])
        assert np.all(X_array[seq_len:2 * seq_len, :] == X_reshaped[1])
        assert np.all(X_array[-seq_len:, :] == X_reshaped[-1])

        # keep only the last column (binary label); drop country name and year
        Y = np.array(Y_array[:, -1], dtype=np.float32)

        X_max = torch.tensor(X_array.max(axis=0))
        X_min = torch.tensor(X_array.min(axis=0))

        # convert to tensors
        X_tensor = torch.tensor(X_reshaped, dtype=torch.float32)
        Y_tensor = torch.tensor(Y, dtype=torch.float32).reshape(-1, 1)

        print(X_tensor.shape)
        print(Y_tensor.shape)

        # Categorical columns were standardized above; this flag tells the helper so.
        X_cate_is_standardized = True
        X_final, cate_one_hot_index_dict = utils.convert_to_one_hot(X_tensor, self, X_cate_is_standardized)
        self.cate_one_hot_index_dict = cate_one_hot_index_dict

        # make sure `convert_to_one_hot()` and `undo_one_hot()` are consistent
        Xs_not_one_hot = utils.undo_all_one_hot(X_final, self)
        assert torch.all(Xs_not_one_hot[:, :, :cate_index_start] == X_tensor[:, :, :cate_index_start])
        assert torch.all(Xs_not_one_hot[:, :, cate_index_start:]
                         == (X_tensor[:, :, cate_index_start:] * self.X_stds[cate_index_start:]
                             + self.X_means[cate_index_start:]).round())

        assert len(Y_tensor.unique()) == 2

        return X_final, Y_tensor, X_max, X_min
