import numpy as np
from sklearn.preprocessing import StandardScaler

from momentfm.utils.data import load_from_tsfile


class ClassificationDataset:
    """Univariate time-series classification dataset read from ``.ts`` files.

    The train and test files are both loaded on construction. Labels are
    re-indexed to ``0 .. L-1`` based on the classes present in the training
    file, and both splits are standardised with a ``StandardScaler`` fitted
    on the training split only. Indexing returns one series left-padded or
    truncated to ``seq_len`` values, together with its input mask and label.
    """

    def __init__(self, dataset_folder: str = None, dataset_name: str = None, data_split="train"):
        """
        Parameters
        ----------
        dataset_folder : str, optional
            Root folder containing one sub-folder per dataset. Must be given
            together with ``dataset_name``.
        dataset_name : str, optional
            Dataset name; the files read are
            ``{dataset_folder}/{dataset_name}/{dataset_name}_TRAIN.ts`` and
            ``..._TEST.ts``. When both arguments are omitted the ECG5000
            files under ``../data`` are used.
        data_split : str
            ``'train'`` serves the training split; any other value serves
            the test split (no separate validation split is built here).

        Raises
        ------
        ValueError
            If only one of ``dataset_folder`` / ``dataset_name`` is given,
            if a file is not equal-length or not univariate, or if a label
            cannot be mapped to a training class.
        """
        self.seq_len = 512

        if dataset_folder is None and dataset_name is None:
            self.train_file_path_and_name = "../data/ECG5000_TRAIN.ts"
            self.test_file_path_and_name = "../data/ECG5000_TEST.ts"
        elif dataset_folder is not None and dataset_name is not None:
            self.train_file_path_and_name = f"{dataset_folder}/{dataset_name}/{dataset_name}_TRAIN.ts"
            self.test_file_path_and_name = f"{dataset_folder}/{dataset_name}/{dataset_name}_TEST.ts"
        else:
            raise ValueError("Either both dataset_folder and dataset_name should be None or both should be provided.")
        self.data_split = data_split

        # Read data
        self._read_data()

    def _transform_labels(self, train_labels: np.ndarray, test_labels: np.ndarray):
        """Map raw labels onto the integer class indices ``0 .. L-1``.

        The index of a label is its position among the sorted unique labels
        of ``train_labels``.

        Parameters
        ----------
        train_labels : np.ndarray
            Raw labels of the training split; defines the set of classes.
        test_labels : np.ndarray
            Raw labels of the test split.

        Returns
        -------
        tuple of np.ndarray
            Integer-encoded ``(train_labels, test_labels)`` with the same
            shapes as the inputs.

        Raises
        ------
        ValueError
            If ``test_labels`` contains a label that never occurs in
            ``train_labels``.
        """
        classes = np.unique(train_labels)  # sorted, so position == class index

        def encode(values, split_name):
            values = np.asarray(values)
            unknown = ~np.isin(values, classes)
            if unknown.any():
                # Fail loudly instead of silently producing None/garbage labels.
                raise ValueError(
                    f"{split_name} labels not present in the training labels: "
                    f"{np.unique(values[unknown]).tolist()}"
                )
            return np.searchsorted(classes, values)

        return encode(train_labels, "Train"), encode(test_labels, "Test")

    def __len__(self):
        """Return the number of time series in the selected split."""
        return self.num_timeseries

    def _read_data(self):
        """Load both splits, encode labels, scale, and select ``data_split``.

        Sets ``scaler``, ``train_*`` / ``test_*`` attributes, ``data`` (stored
        transposed, one series per column), ``labels``, ``num_timeseries`` and
        ``len_timeseries`` (series length, capped at ``seq_len``).
        """
        self.scaler = StandardScaler()

        self.train_data, self.train_labels, self.train_meta_data = load_from_tsfile(
            self.train_file_path_and_name, return_meta_data=True
        )
        self.test_data, self.test_labels, self.test_meta_data = load_from_tsfile(
            self.test_file_path_and_name, return_meta_data=True
        )

        self.train_labels, self.test_labels = self._transform_labels(
            self.train_labels, self.test_labels
        )

        # Explicit raises rather than asserts so the checks survive `python -O`.
        for prop in ("equallength", "univariate"):
            for split_name, meta_data in (
                ("Train", self.train_meta_data),
                ("Test", self.test_meta_data),
            ):
                if not meta_data[prop]:
                    raise ValueError(f"{split_name} data is not {prop}")

        # Flatten for scaling: collapse the leading axes so each row is one series
        train_data_reshaped = self.train_data.reshape(-1, self.train_data.shape[2])
        test_data_reshaped = self.test_data.reshape(-1, self.test_data.shape[2])

        # Fit only on training data so no test statistics leak into the scaler
        self.scaler.fit(train_data_reshaped)

        # Transform separately
        self.train_data = self.scaler.transform(train_data_reshaped)
        self.test_data = self.scaler.transform(test_data_reshaped)

        if self.data_split == "train":
            self.data = self.train_data
            self.labels = self.train_labels
        else:
            self.data = self.test_data
            self.labels = self.test_labels

        self.num_timeseries = self.data.shape[0]
        self.len_timeseries = min(self.data.shape[1], self.seq_len)

        # Stored as (series_length, num_timeseries); __getitem__ reads columns.
        self.data = self.data.T

    def __getitem__(self, index):
        """Return the ``index``-th sample of the selected split.

        Parameters
        ----------
        index : int
            Position of the sample; must be smaller than ``len(self)``.

        Returns
        -------
        timeseries : np.ndarray
            Array of shape ``(1, seq_len)``. Series shorter than ``seq_len``
            are left-padded with zeros; longer ones keep only their last
            ``seq_len`` values.
        input_mask : np.ndarray
            Array of shape ``(seq_len,)`` with 0 on padded positions and 1
            elsewhere.
        labels : int
            Integer class index of the sample.

        Raises
        ------
        IndexError
            If ``index`` is past the end of the split. Raising ``IndexError``
            (rather than asserting) also lets plain ``for`` loops over the
            dataset terminate normally.
        """
        if index >= len(self):
            raise IndexError(
                f"index {index} is out of range for dataset with {len(self)} time series"
            )

        timeseries = self.data[:, index]
        labels = self.labels[index,].astype(int)
        input_mask = np.ones(self.seq_len)

        pad_len = self.seq_len - len(timeseries)
        if pad_len >= 0:
            # Left-pad so the observed values sit at the end of the window.
            input_mask[:pad_len] = 0
            timeseries = np.pad(timeseries, (pad_len, 0))
        else:
            timeseries = timeseries[-self.seq_len:]

        return np.expand_dims(timeseries, axis=0), input_mask, labels
    
