# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import os
import numpy as np
from aeon.datasets import load_from_tsfile
import torch


class Dataset_JapaneseVowels:
    """Loader for the equal-length JapaneseVowels multivariate time-series dataset.

    Each sample is a sequence of ``length_of_sequence`` (25) time steps with
    ``num_of_features`` (12) continuous features. The integer class labels
    1-9 stored in the ``.ts`` files are collapsed into a binary target:
    labels 1-4 map to 0 and labels 5-9 map to 1.
    """

    def __init__(self, data_path_root, result_folder, feature_proximity_weights=None):
        """Set up dataset metadata and per-feature proximity weights.

        Args:
            data_path_root: Root directory under which
                ``data/multivariate_time_series/JapaneseVowels/`` is located.
            result_folder: Directory where the per-feature standardization
                means/stds are written as CSV files by ``load_dataset``;
                ``None`` disables writing them.
            feature_proximity_weights: Optional per-feature weights with
                ``num_of_features`` elements. Defaults to all ones. Stored as
                a ``(1, num_of_features)`` tensor.

        Raises:
            ValueError: If ``feature_proximity_weights`` does not contain
                exactly ``num_of_features`` elements.
        """
        self.data_path_root = data_path_root
        self.result_folder = result_folder
        self.length_of_sequence = 25
        self.num_of_features = 12
        # All features are continuous, so there is no categorical metadata.
        self.categorical_features = []
        self.cate_num_classes_dict = {}
        self.input_dimension = self.num_of_features
        self.output_dimension = 1
        self.feature_names = ["feature_{}".format(i) for i in range(1, self.num_of_features + 1)]
        self.cate_index_start = None
        self.cate_one_hot_index_dict = {}
        # Populated by load_dataset().
        self.X_means = None
        self.X_stds = None

        if feature_proximity_weights is None:
            self.feature_proximity_weights = torch.tensor([1.0 for _ in range(self.num_of_features)]).reshape(1, -1)
        else:
            self.feature_proximity_weights = torch.tensor(feature_proximity_weights).reshape(1, -1)

            if self.feature_proximity_weights.shape != (1, self.num_of_features):
                raise ValueError("self.feature_proximity_weights.shape {} != (1, self.num_of_features {})."
                                 .format(self.feature_proximity_weights.shape, self.num_of_features))

    def load_dataset(self, mode):
        """Load one split, binarize its labels and standardize its features.

        Args:
            mode: ``"train"`` or ``"test"``; selects which ``.ts`` file is read.

        Returns:
            Tuple ``(X, y, X_max, X_min)``:
              * ``X``: float32 tensor of shape
                ``(n_samples, length_of_sequence, num_of_features)``,
                standardized per feature.
              * ``y``: float32 tensor of shape ``(n_samples, 1)`` holding 0/1.
              * ``X_max`` / ``X_min``: per-feature max / min of the
                standardized data, shape ``(num_of_features,)``. Their dtype
                follows the loaded data.

        Side effects:
            Sets ``self.X_means`` / ``self.X_stds``. If ``result_folder`` is set,
            it writes them to ``JapaneseVowels_means.csv`` /
            ``JapaneseVowels_stds.csv`` in that folder. Also prints
            diagnostic information.

        Raises:
            ValueError: On an invalid ``mode``, an unexpected label, unexpected
                sample shape, a feature that does not look continuous, or if
                the split does not contain both binary classes.
        """
        if mode == "train":
            data_path = os.path.join(self.data_path_root, "data/multivariate_time_series/JapaneseVowels/JapaneseVowels_eq_TRAIN.ts")
        elif mode == "test":
            data_path = os.path.join(self.data_path_root, "data/multivariate_time_series/JapaneseVowels/JapaneseVowels_eq_TEST.ts")
        else:
            raise ValueError(f"invalid mode: {mode}.")

        raw_x, raw_y = load_from_tsfile(data_path)
        # Swap the last two axes so that the time axis comes before the feature
        # axis; the shape check below enforces (n, length_of_sequence, num_of_features).
        x_np = np.transpose(raw_x, (0, 2, 1))
        labels = raw_y.astype(float)

        # Binarize labels: 1-4 -> 0, 5-9 -> 1. Anything else (including NaN or
        # values between 4 and 5) is rejected, reporting the first offender.
        is_low = (labels >= 1) & (labels <= 4)
        is_high = (labels >= 5) & (labels <= 9)
        invalid = ~(is_low | is_high)
        if invalid.any():
            raise ValueError("What??? y={}".format(labels[invalid][0]))
        y_np = is_high.astype(float).reshape(-1, 1)

        # Explicit checks (not asserts) so they still run under `python -O`.
        if x_np.shape[1] != self.length_of_sequence or x_np.shape[2] != self.num_of_features:
            raise ValueError("Expected samples of shape ({}, {}) after transposition, got {}."
                             .format(self.length_of_sequence, self.num_of_features, x_np.shape[1:]))

        # Every feature is expected to be continuous; a feature with few distinct
        # values is treated as categorical, which this loader does not support.
        for i in range(self.num_of_features):
            number_of_unique_values = len(np.unique(x_np[:, :, i]))
            print("number_of_unique_values of feature {}: {}".format(i, number_of_unique_values))
            if number_of_unique_values >= 100:
                print("feature {} should be continuous.".format(i))
            else:
                raise ValueError("feature {} should be categorical.".format(i))

        # Start dataset standardization: flatten (sample, time) so statistics are
        # per feature across all time steps of all samples.
        x_2d = x_np.reshape(-1, self.num_of_features)
        # Sanity checks on the flattening order; these also fail if the data contains NaN.
        assert np.all(x_np[0, :, :] == x_2d[:self.length_of_sequence, :])
        assert np.all(x_np[1, :, :] == x_2d[self.length_of_sequence:self.length_of_sequence * 2, :])
        assert np.all(x_np[-1, :, :] == x_2d[-self.length_of_sequence:, :])

        # NOTE(review): statistics are recomputed from whichever split is being
        # loaded, so mode="test" standardizes with test statistics and overwrites
        # the CSV files written by mode="train". Confirm this is intended.
        self.X_means = x_2d.mean(axis=0)
        self.X_stds = x_2d.std(axis=0)
        print("X_means: \n", self.X_means)
        print("X_stds: \n", self.X_stds)
        if self.result_folder is not None:
            np.savetxt(os.path.join(self.result_folder, "JapaneseVowels_means.csv"), self.X_means, delimiter=',')
            np.savetxt(os.path.join(self.result_folder, "JapaneseVowels_stds.csv"), self.X_stds, delimiter=',')

        # Continuity check above (>= 100 distinct values) guarantees non-zero stds.
        X_2d_stand = (x_2d - self.X_means) / self.X_stds

        X_max = X_2d_stand.max(axis=0)
        X_min = X_2d_stand.min(axis=0)

        assert self.X_means.shape == (self.num_of_features,)
        assert self.X_stds.shape == (self.num_of_features,)
        assert X_max.shape == (self.num_of_features,)
        assert X_min.shape == (self.num_of_features,)

        X_stand = X_2d_stand.reshape(-1, self.length_of_sequence, self.num_of_features)

        assert np.all(X_stand[0, :, :] == X_2d_stand[:self.length_of_sequence, :])
        assert np.all(X_stand[1, :, :] == X_2d_stand[self.length_of_sequence:self.length_of_sequence * 2, :])
        assert np.all(X_stand[-1, :, :] == X_2d_stand[-self.length_of_sequence:, :])

        X_stand_tensor = torch.tensor(X_stand, dtype=torch.float32)
        y_tensor = torch.tensor(y_np, dtype=torch.float32).reshape(-1, 1)
        X_max_tensor = torch.tensor(X_max)
        X_min_tensor = torch.tensor(X_min)

        present_classes = y_tensor.unique()
        if len(present_classes) != 2:
            raise ValueError("Expected both binary classes to be present, found {}."
                             .format(present_classes.tolist()))

        return X_stand_tensor, y_tensor, X_max_tensor, X_min_tensor
