# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import os
import numpy as np
import torch
from aeon.datasets import load_from_tsfile


def dset2npy(fp, dataset_name):
    """Copy the dataset stored under ``dataset_name`` in ``fp`` into a new NumPy array.

    Args:
        fp: Container indexable by ``dataset_name``. The looked-up object must
            expose ``shape``, ``dtype`` and ``read_direct(dest)``
            (presumably an open h5py file or group -- confirm against callers).
        dataset_name: Key of the dataset to read.

    Returns:
        A freshly allocated ``numpy.ndarray`` with the dataset's shape and
        dtype, filled by ``read_direct``.
    """
    source = fp[dataset_name]
    # Allocate the destination up front so the dataset can write straight into it.
    destination = np.empty(shape=source.shape, dtype=source.dtype)
    source.read_direct(destination)
    return destination


class Dataset_NATOPS:
    """Loader for the NATOPS multivariate time-series dataset.

    Each sample is a sequence of ``length_of_sequence`` (51) steps with
    ``num_of_features`` (24) continuous features (X/Y/Z coordinates of the
    left/right hand tips, elbows, wrists and thumbs). The six original class
    labels are collapsed into a binary target by :meth:`load_dataset`.

    Attributes set by ``__init__``:
        data_path_root: Root directory under which the ``data/...`` tree lives.
        result_folder: Directory where standardization statistics are written,
            or ``None`` to skip writing them.
        feature_proximity_weights: ``float32`` tensor of shape
            ``(1, num_of_features)``.
        X_means, X_stds: Per-feature statistics; ``None`` until
            :meth:`load_dataset` has been called.
    """

    def __init__(self, data_path_root, result_folder, feature_proximity_weights=None):
        """Store dataset metadata and the per-feature proximity weights.

        Args:
            data_path_root: Root directory containing the ``data`` folder.
            result_folder: Output directory for the means/stds CSV files, or
                ``None`` to disable writing them.
            feature_proximity_weights: Optional sequence/array/tensor with one
                weight per feature. Defaults to all ones.

        Raises:
            ValueError: If the supplied weights do not contain exactly
                ``num_of_features`` values.
        """
        self.data_path_root = data_path_root
        self.result_folder = result_folder
        self.length_of_sequence = 51
        self.num_of_features = 24
        # This dataset declares no categorical features; the containers are
        # kept (empty) alongside the other metadata attributes.
        self.categorical_features = []
        self.cate_num_classes_dict = {}
        self.input_dimension = self.num_of_features
        self.output_dimension = 1
        self.feature_names = ["Hand tip left, X coordinate", "Hand tip left, Y coordinate",
                              "Hand tip left, Z coordinate", "Hand tip right, X coordinate",
                              "Hand tip right, Y coordinate", "Hand tip right, Z coordinate",
                              "Elbow left, X coordinate", "Elbow left, Y coordinate", "Elbow left, Z coordinate",
                              "Elbow right, X coordinate", "Elbow right, Y coordinate", "Elbow right, Z coordinate",
                              "Wrist left, X coordinate", "Wrist left, Y coordinate", "Wrist left, Z coordinate",
                              "Wrist right, X coordinate", "Wrist right, Y coordinate", "Wrist right, Z coordinate",
                              "Thumb left, X coordinate", "Thumb left, Y coordinate", "Thumb left, Z coordinate",
                              "Thumb right, X coordinate", "Thumb right, Y coordinate", "Thumb right, Z coordinate"]
        self.cate_index_start = None
        self.cate_one_hot_index_dict = {}
        self.X_means = None
        self.X_stds = None

        if feature_proximity_weights is None:
            self.feature_proximity_weights = torch.ones(1, self.num_of_features)
        else:
            # Force float32 so user-supplied weights (e.g. a list of ints or a
            # float64 array) have the same dtype as the default weights and as
            # the tensors returned by load_dataset.
            self.feature_proximity_weights = torch.tensor(
                feature_proximity_weights, dtype=torch.float32
            ).reshape(1, -1)

            if self.feature_proximity_weights.shape != (1, self.num_of_features):
                raise ValueError("self.feature_proximity_weights.shape {} != (1, self.num_of_features {})."
                                 .format(self.feature_proximity_weights.shape, self.num_of_features))

    def load_dataset(self, mode):
        """Load one split, binarize the labels and standardize the features.

        Side effects: sets ``self.X_means`` / ``self.X_stds``, prints
        diagnostics, and (when ``self.result_folder`` is not ``None``) writes
        ``NATOPS_means.csv`` and ``NATOPS_stds.csv`` into that folder.

        NOTE(review): the statistics are computed from whichever split is
        being loaded, and both modes write to the same CSV file names, so
        loading "test" replaces the values produced by loading "train".
        Confirm this is intended.

        Args:
            mode: ``"train"`` or ``"test"``.

        Returns:
            Tuple of four ``float32`` tensors:
            standardized features ``(n_samples, length_of_sequence, num_of_features)``,
            binary labels ``(n_samples, 1)``,
            per-feature maximum ``(num_of_features,)`` and
            per-feature minimum ``(num_of_features,)`` of the standardized data.

        Raises:
            ValueError: If ``mode`` is invalid, the loaded data does not have
                the expected shape, contains non-finite values, a feature has
                too few distinct values to be treated as continuous, or the
                binarized labels do not contain exactly two classes.
        """
        if mode == "train":
            ts_file_name = "NATOPS_TRAIN.ts"
        elif mode == "test":
            ts_file_name = "NATOPS_TEST.ts"
        else:
            raise ValueError(f"invalid mode: {mode}.")
        data_path = os.path.join(self.data_path_root,
                                 "data/multivariate_time_series/NATOPS/" + ts_file_name)

        raw_x, raw_y = load_from_tsfile(data_path)
        # Swap the last two axes so that features end up on the last axis.
        x_np = np.transpose(raw_x, (0, 2, 1))
        y_np = raw_y.astype(float)

        # Explicit validation (instead of asserts, which vanish under -O).
        expected_sample_shape = (self.length_of_sequence, self.num_of_features)
        if x_np.shape[1:] != expected_sample_shape:
            raise ValueError("unexpected NATOPS sample shape {}; expected {}."
                             .format(x_np.shape[1:], expected_sample_shape))
        if not np.isfinite(x_np).all():
            raise ValueError("NATOPS features contain NaN or infinite values.")

        # Sanity check: every feature is expected to be continuous, i.e. to
        # have a number of distinct values close to the number of observations.
        min_unique_for_continuous = len(x_np) * self.length_of_sequence * 0.8
        for i in range(self.num_of_features):
            number_of_unique_values = len(np.unique(x_np[:, :, i]))
            print("number_of_unique_values of feature {}: {}".format(i, number_of_unique_values))
            if number_of_unique_values > min_unique_for_continuous:
                print("feature {} should be continuous.".format(i))
            else:
                raise ValueError("feature {} should be categorical.".format(i))

        # make Y binary: classes 1-3 -> 0, classes 4-6 -> 1.
        # Both masks are computed before any write so the mapping does not
        # depend on assignment order.
        negative_mask = np.isin(y_np, (1, 2, 3))
        positive_mask = np.isin(y_np, (4, 5, 6))
        y_np[negative_mask] = 0
        y_np[positive_mask] = 1

        # Start dataset standardization: flatten (sample, step) into rows so
        # the statistics are taken per feature over all samples and steps.
        x_2d = x_np.reshape(-1, self.num_of_features)

        self.X_means = x_2d.mean(axis=0)
        self.X_stds = x_2d.std(axis=0)
        print("X_means: \n", self.X_means)
        print("X_stds: \n", self.X_stds)
        if self.result_folder is not None:
            np.savetxt(os.path.join(self.result_folder, "NATOPS_means.csv"),
                       self.X_means, delimiter=',')
            np.savetxt(os.path.join(self.result_folder, "NATOPS_stds.csv"),
                       self.X_stds, delimiter=',')

        x_2d_stand = (x_2d - self.X_means) / self.X_stds
        x_max = x_2d_stand.max(axis=0)
        x_min = x_2d_stand.min(axis=0)

        # Restore the (sample, step, feature) layout.
        x_stand = x_2d_stand.reshape(-1, self.length_of_sequence, self.num_of_features)

        y_tensor = torch.tensor(y_np, dtype=torch.float32).reshape(-1, 1)
        if y_tensor.unique().numel() != 2:
            raise ValueError("expected exactly 2 label classes after binarization, got {}."
                             .format(y_tensor.unique().numel()))

        return (torch.tensor(x_stand).float(),
                y_tensor.float(),
                torch.tensor(x_max).float(),
                torch.tensor(x_min).float())
