import os
from torch.utils.data import DataLoader
from utils.utils import Color
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
# Some code based on https://github.com/thuml/Time-Series-Library


class PSMSegLoader(object):
    """
    Segment-wise data loader for the PSM dataset.

    Loads ``train.csv``, ``test.csv`` and ``test_label.csv`` from ``data_path``,
    drops the first column of each file (presumably a timestamp/index column --
    confirm against the data), replaces NaNs with 0 in the features, and
    standardizes features with a ``StandardScaler`` fitted on the training file
    only. The training file is split chronologically into train (first 80%)
    and validation (last 20%) partitions.

    Windowing:
      * ``train`` / ``val``: windows of ``win_size`` rows starting every ``step``
        rows. No labels are loaded for these partitions, so the first
        ``win_size`` test labels are returned as a placeholder.
      * ``test``: non-overlapping windows of ``win_size`` rows (``step`` is not
        used), paired with the matching slice of test labels.

    Args:
        data_path: Directory containing the PSM CSV files.
        win_size: Number of time steps per window.
        step: Stride between consecutive train/val windows.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.
    """
    def __init__(self, data_path, win_size, step, mode="train"):
        self.mode = mode
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        # The scaler is fitted on the whole training file, before the train/val split.
        data = pd.read_csv(os.path.join(data_path, 'train.csv'))
        data = np.nan_to_num(data.values[:, 1:])
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_data = pd.read_csv(os.path.join(data_path, 'test.csv'))
        test_data = np.nan_to_num(test_data.values[:, 1:])
        self.test = self.scaler.transform(test_data)
        # Split train data into train (80%) and validation (20%)
        split = int(len(data) * 0.8)
        self.train = data[:split]
        self.val = data[split:]
        self.test_labels = pd.read_csv(os.path.join(data_path, 'test_label.csv')).values[:, 1:]
        # Print shape of selected split
        if self.mode == "train":
            print(Color.GREEN + "train:", self.train.shape, Color.RESET)
        elif self.mode == 'val':
            print(Color.GREEN + "validation:", self.val.shape, Color.RESET)
        elif self.mode == 'test':
            print(Color.GREEN + "test:", self.test.shape, Color.RESET)

    def _active_split(self):
        """Return the standardized array served in the current ``mode``.

        Raises:
            ValueError: if ``mode`` is not ``"train"``, ``"val"`` or ``"test"``.
        """
        if self.mode == "train":
            return self.train
        if self.mode == 'val':
            return self.val
        if self.mode == 'test':
            return self.test
        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train', 'val' or 'test'")

    def __len__(self):
        """Return the number of windows in the current mode (never negative)."""
        rows = self._active_split().shape[0]
        # Test windows are non-overlapping, so their stride is win_size, not step.
        stride = self.win_size if self.mode == 'test' else self.step
        # Clamp at 0: a split shorter than one window has no windows, rather than
        # a negative length that len() would reject.
        return max(0, (rows - self.win_size) // stride + 1)

    def __getitem__(self, index):
        """Return ``(window, labels)`` as float32 arrays for window number ``index``."""
        series = self._active_split()
        if self.mode == 'test':
            # Window ``index`` starts at index * win_size; labels share the same slice.
            start = index * self.win_size
            end = start + self.win_size
            return np.float32(series[start:end]), np.float32(self.test_labels[start:end])
        start = index * self.step
        # Placeholder labels for train/val (no labels are loaded for those splits).
        return (np.float32(series[start:start + self.win_size]),
                np.float32(self.test_labels[0:self.win_size]))


class MSLSegLoader(object):
    """
    Segment-wise data loader for the MSL dataset.

    Loads ``MSL_train.npy``, ``MSL_test.npy`` and ``MSL_test_label.npy`` from
    ``data_path`` and standardizes features with a ``StandardScaler`` fitted on
    the training array only. The training array is split chronologically into
    train (first 80%) and validation (last 20%) partitions.

    Windowing:
      * ``train`` / ``val``: windows of ``win_size`` rows starting every ``step``
        rows. No labels are loaded for these partitions, so the first
        ``win_size`` test labels are returned as a placeholder.
      * ``test``: non-overlapping windows of ``win_size`` rows (``step`` is not
        used), paired with the matching slice of test labels.

    Args:
        data_path: Directory containing the MSL ``.npy`` files.
        win_size: Number of time steps per window.
        step: Stride between consecutive train/val windows.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.
    """
    def __init__(self, data_path, win_size, step, mode="train"):
        self.mode = mode
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        # The scaler is fitted on the whole training array, before the train/val split.
        data = np.load(os.path.join(data_path, "MSL_train.npy"))
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_data = np.load(os.path.join(data_path, "MSL_test.npy"))
        self.test = self.scaler.transform(test_data)
        # Split train data into train (80%) and validation (20%)
        split = int(len(data) * 0.8)
        self.train = data[:split]
        self.val = data[split:]
        self.test_labels = np.load(os.path.join(data_path, "MSL_test_label.npy"))
        # Print shape of selected split
        if self.mode == "train":
            print(Color.GREEN + "train:", self.train.shape, Color.RESET)
        elif self.mode == 'val':
            print(Color.GREEN + "validation:", self.val.shape, Color.RESET)
        elif self.mode == 'test':
            print(Color.GREEN + "test:", self.test.shape, Color.RESET)

    def _active_split(self):
        """Return the standardized array served in the current ``mode``.

        Raises:
            ValueError: if ``mode`` is not ``"train"``, ``"val"`` or ``"test"``.
        """
        if self.mode == "train":
            return self.train
        if self.mode == 'val':
            return self.val
        if self.mode == 'test':
            return self.test
        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train', 'val' or 'test'")

    def __len__(self):
        """Return the number of windows in the current mode (never negative)."""
        rows = self._active_split().shape[0]
        # Test windows are non-overlapping, so their stride is win_size, not step.
        stride = self.win_size if self.mode == 'test' else self.step
        # Clamp at 0: a split shorter than one window has no windows, rather than
        # a negative length that len() would reject.
        return max(0, (rows - self.win_size) // stride + 1)

    def __getitem__(self, index):
        """Return ``(window, labels)`` as float32 arrays for window number ``index``."""
        series = self._active_split()
        if self.mode == 'test':
            # Window ``index`` starts at index * win_size; labels share the same slice.
            start = index * self.win_size
            end = start + self.win_size
            return np.float32(series[start:end]), np.float32(self.test_labels[start:end])
        start = index * self.step
        # Placeholder labels for train/val (no labels are loaded for those splits).
        return (np.float32(series[start:start + self.win_size]),
                np.float32(self.test_labels[0:self.win_size]))


class SMAPSegLoader(object):
    """
    Segment-wise data loader for the SMAP dataset.

    Loads ``SMAP_train.npy``, ``SMAP_test.npy`` and ``SMAP_test_label.npy`` from
    ``data_path`` and standardizes features with a ``StandardScaler`` fitted on
    the training array only. The training array is split chronologically into
    train (first 80%) and validation (last 20%) partitions.

    Windowing:
      * ``train`` / ``val``: windows of ``win_size`` rows starting every ``step``
        rows. No labels are loaded for these partitions, so the first
        ``win_size`` test labels are returned as a placeholder.
      * ``test``: non-overlapping windows of ``win_size`` rows (``step`` is not
        used), paired with the matching slice of test labels.

    Args:
        data_path: Directory containing the SMAP ``.npy`` files.
        win_size: Number of time steps per window.
        step: Stride between consecutive train/val windows.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.
    """
    def __init__(self, data_path, win_size, step, mode="train"):
        self.mode = mode
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        # The scaler is fitted on the whole training array, before the train/val split.
        data = np.load(os.path.join(data_path, "SMAP_train.npy"))
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_data = np.load(os.path.join(data_path, "SMAP_test.npy"))
        self.test = self.scaler.transform(test_data)
        # Split train data into train (80%) and validation (20%)
        split = int(len(data) * 0.8)
        self.train = data[:split]
        self.val = data[split:]
        self.test_labels = np.load(os.path.join(data_path, "SMAP_test_label.npy"))
        # Print shape of selected split
        if self.mode == "train":
            print(Color.GREEN + "train:", self.train.shape, Color.RESET)
        elif self.mode == 'val':
            print(Color.GREEN + "validation:", self.val.shape, Color.RESET)
        elif self.mode == 'test':
            print(Color.GREEN + "test:", self.test.shape, Color.RESET)

    def _active_split(self):
        """Return the standardized array served in the current ``mode``.

        Raises:
            ValueError: if ``mode`` is not ``"train"``, ``"val"`` or ``"test"``.
        """
        if self.mode == "train":
            return self.train
        if self.mode == 'val':
            return self.val
        if self.mode == 'test':
            return self.test
        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train', 'val' or 'test'")

    def __len__(self):
        """Return the number of windows in the current mode (never negative)."""
        rows = self._active_split().shape[0]
        # Test windows are non-overlapping, so their stride is win_size, not step.
        stride = self.win_size if self.mode == 'test' else self.step
        # Clamp at 0: a split shorter than one window has no windows, rather than
        # a negative length that len() would reject.
        return max(0, (rows - self.win_size) // stride + 1)

    def __getitem__(self, index):
        """Return ``(window, labels)`` as float32 arrays for window number ``index``."""
        series = self._active_split()
        if self.mode == 'test':
            # Window ``index`` starts at index * win_size; labels share the same slice.
            start = index * self.win_size
            end = start + self.win_size
            return np.float32(series[start:end]), np.float32(self.test_labels[start:end])
        start = index * self.step
        # Placeholder labels for train/val (no labels are loaded for those splits).
        return (np.float32(series[start:start + self.win_size]),
                np.float32(self.test_labels[0:self.win_size]))


class SMDSegLoader(object):
    """
    Segment-wise data loader for the SMD dataset.

    Loads ``SMD_train.npy``, ``SMD_test.npy`` and ``SMD_test_label.npy`` from
    ``data_path`` and standardizes features with a ``StandardScaler`` fitted on
    the training array only. The training array is split chronologically into
    train (first 80%) and validation (last 20%) partitions.

    Windowing:
      * ``train`` / ``val``: windows of ``win_size`` rows starting every ``step``
        rows. No labels are loaded for these partitions, so the first
        ``win_size`` test labels are returned as a placeholder.
      * ``test``: non-overlapping windows of ``win_size`` rows (``step`` is not
        used), paired with the matching slice of test labels.

    Args:
        data_path: Directory containing the SMD ``.npy`` files.
        win_size: Number of time steps per window.
        step: Stride between consecutive train/val windows.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.
    """
    def __init__(self, data_path, win_size, step, mode="train"):
        self.mode = mode
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        # The scaler is fitted on the whole training array, before the train/val split.
        data = np.load(os.path.join(data_path, "SMD_train.npy"))
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_data = np.load(os.path.join(data_path, "SMD_test.npy"))
        self.test = self.scaler.transform(test_data)
        # Split train data into train (80%) and validation (20%)
        split = int(len(data) * 0.8)
        self.train = data[:split]
        self.val = data[split:]
        self.test_labels = np.load(os.path.join(data_path, "SMD_test_label.npy"))
        # Print shape of selected split
        if self.mode == "train":
            print(Color.GREEN + "train:", self.train.shape, Color.RESET)
        elif self.mode == 'val':
            print(Color.GREEN + "validation:", self.val.shape, Color.RESET)
        elif self.mode == 'test':
            print(Color.GREEN + "test:", self.test.shape, Color.RESET)

    def _active_split(self):
        """Return the standardized array served in the current ``mode``.

        Raises:
            ValueError: if ``mode`` is not ``"train"``, ``"val"`` or ``"test"``.
        """
        if self.mode == "train":
            return self.train
        if self.mode == 'val':
            return self.val
        if self.mode == 'test':
            return self.test
        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train', 'val' or 'test'")

    def __len__(self):
        """Return the number of windows in the current mode (never negative)."""
        rows = self._active_split().shape[0]
        # Test windows are non-overlapping, so their stride is win_size, not step.
        stride = self.win_size if self.mode == 'test' else self.step
        # Clamp at 0: a split shorter than one window has no windows, rather than
        # a negative length that len() would reject.
        return max(0, (rows - self.win_size) // stride + 1)

    def __getitem__(self, index):
        """Return ``(window, labels)`` as float32 arrays for window number ``index``."""
        series = self._active_split()
        if self.mode == 'test':
            # Window ``index`` starts at index * win_size; labels share the same slice.
            start = index * self.win_size
            end = start + self.win_size
            return np.float32(series[start:end]), np.float32(self.test_labels[start:end])
        start = index * self.step
        # Placeholder labels for train/val (no labels are loaded for those splits).
        return (np.float32(series[start:start + self.win_size]),
                np.float32(self.test_labels[0:self.win_size]))


class SWATSegLoader(object):
    """
    Segment-wise data loader for the SWaT dataset.

    Loads ``swat_train2.csv`` and ``swat2.csv`` from ``root_path``. The last
    column of the test file is used as the label column, and the last column is
    dropped from both files' features. Features are standardized with a
    ``StandardScaler`` fitted on the training file only. The training file is
    split chronologically into train (first 80%) and validation (last 20%).

    Windowing:
      * ``train`` / ``val``: windows of ``win_size`` rows starting every ``step``
        rows. No labels are loaded for these partitions, so the first
        ``win_size`` test labels are returned as a placeholder.
      * ``test``: non-overlapping windows of ``win_size`` rows (``step`` is not
        used), paired with the matching slice of test labels.

    Args:
        root_path: Directory containing the SWaT CSV files.
        win_size: Number of time steps per window.
        step: Stride between consecutive train/val windows.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.
    """
    def __init__(self, root_path, win_size, step=1, mode="train"):
        self.mode = mode
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        train_df = pd.read_csv(os.path.join(root_path, 'swat_train2.csv'))
        test_df = pd.read_csv(os.path.join(root_path, 'swat2.csv'))
        labels = test_df.values[:, -1:]
        # The last training column is dropped as well (presumably also a label
        # column -- confirm against the data).
        train_data = train_df.values[:, :-1]
        test_data = test_df.values[:, :-1]
        self.scaler.fit(train_data)
        train_data = self.scaler.transform(train_data)
        test_data = self.scaler.transform(test_data)
        # Split train data into train (80%) and validation (20%)
        split = int(len(train_data) * 0.8)
        self.train = train_data[:split]
        self.val = train_data[split:]
        self.test = test_data
        self.test_labels = labels
        # Print shape of selected split
        if self.mode == "train":
            print(Color.GREEN + "train:", self.train.shape, Color.RESET)
        elif self.mode == 'val':
            print(Color.GREEN + "validation:", self.val.shape, Color.RESET)
        elif self.mode == 'test':
            print(Color.GREEN + "test:", self.test.shape, Color.RESET)

    def _active_split(self):
        """Return the standardized array served in the current ``mode``.

        Raises:
            ValueError: if ``mode`` is not ``"train"``, ``"val"`` or ``"test"``.
        """
        if self.mode == "train":
            return self.train
        if self.mode == 'val':
            return self.val
        if self.mode == 'test':
            return self.test
        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train', 'val' or 'test'")

    def __len__(self):
        """Return the number of windows in the current mode (never negative)."""
        rows = self._active_split().shape[0]
        # Test windows are non-overlapping, so their stride is win_size, not step.
        stride = self.win_size if self.mode == 'test' else self.step
        # Clamp at 0: a split shorter than one window has no windows, rather than
        # a negative length that len() would reject.
        return max(0, (rows - self.win_size) // stride + 1)

    def __getitem__(self, index):
        """Return ``(window, labels)`` as float32 arrays for window number ``index``."""
        series = self._active_split()
        if self.mode == 'test':
            # Window ``index`` starts at index * win_size; labels share the same slice.
            start = index * self.win_size
            end = start + self.win_size
            return np.float32(series[start:end]), np.float32(self.test_labels[start:end])
        start = index * self.step
        # Placeholder labels for train/val (no labels are loaded for those splits).
        return (np.float32(series[start:start + self.win_size]),
                np.float32(self.test_labels[0:self.win_size]))


class SynSegLoader(object):
    """
    Segment-wise data loader for the TODS (synthetic) dataset.

    Loads ``train_tods.csv`` (all columns used as features) and
    ``test_tods.csv`` (last column used as labels, remaining columns as
    features) from ``data_path``. NaNs are replaced with 0, and features are
    standardized with a ``StandardScaler`` fitted on the training file only.
    The training file is split chronologically into train (first 80%) and
    validation (last 20%) partitions.

    Windowing:
      * ``train`` / ``val``: windows of ``win_size`` rows starting every ``step``
        rows. No labels are loaded for these partitions, so the first
        ``win_size`` test labels are returned as a placeholder.
      * ``test``: non-overlapping windows of ``win_size`` rows (``step`` is not
        used), paired with the matching slice of test labels.

    Args:
        data_path: Directory containing the TODS CSV files.
        win_size: Number of time steps per window.
        step: Stride between consecutive train/val windows.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.
    """
    def __init__(self, data_path, win_size, step, mode="train"):
        self.mode = mode
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        # The scaler is fitted on the whole training file, before the train/val split.
        data = pd.read_csv(os.path.join(data_path, 'train_tods.csv'))
        data = np.nan_to_num(data.values)
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_df = pd.read_csv(os.path.join(data_path, 'test_tods.csv'))
        test_data = np.nan_to_num(test_df.values[:, :-1])
        self.test = self.scaler.transform(test_data)
        # Split train data into train (80%) and validation (20%)
        split = int(len(data) * 0.8)
        self.train = data[:split]
        self.val = data[split:]
        self.test_labels = test_df.values[:, -1:]
        # Print shape of selected split
        if self.mode == "train":
            print(Color.GREEN + "train:", self.train.shape, Color.RESET)
        elif self.mode == 'val':
            print(Color.GREEN + "validation:", self.val.shape, Color.RESET)
        elif self.mode == 'test':
            print(Color.GREEN + "test:", self.test.shape, Color.RESET)

    def _active_split(self):
        """Return the standardized array served in the current ``mode``.

        Raises:
            ValueError: if ``mode`` is not ``"train"``, ``"val"`` or ``"test"``.
        """
        if self.mode == "train":
            return self.train
        if self.mode == 'val':
            return self.val
        if self.mode == 'test':
            return self.test
        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train', 'val' or 'test'")

    def __len__(self):
        """Return the number of windows in the current mode (never negative)."""
        rows = self._active_split().shape[0]
        # Test windows are non-overlapping, so their stride is win_size, not step.
        stride = self.win_size if self.mode == 'test' else self.step
        # Clamp at 0: a split shorter than one window has no windows, rather than
        # a negative length that len() would reject.
        return max(0, (rows - self.win_size) // stride + 1)

    def __getitem__(self, index):
        """Return ``(window, labels)`` as float32 arrays for window number ``index``."""
        series = self._active_split()
        if self.mode == 'test':
            # Window ``index`` starts at index * win_size; labels share the same slice.
            start = index * self.win_size
            end = start + self.win_size
            return np.float32(series[start:end]), np.float32(self.test_labels[start:end])
        start = index * self.step
        # Placeholder labels for train/val (no labels are loaded for those splits).
        return (np.float32(series[start:start + self.win_size]),
                np.float32(self.test_labels[0:self.win_size]))


def get_loader_segment(data_path, batch_size, win_size=100, step=100, mode='train', dataset='KDD',
                       num_workers=8):
    """
    Build a ``DataLoader`` over sliding windows of an anomaly-detection benchmark.

    Args:
        data_path: Directory holding the dataset files.
        batch_size: Number of windows per batch.
        win_size: Number of time steps per window.
        step: Stride between consecutive train/val windows.
        mode: ``"train"``, ``"val"`` or ``"test"``; only ``"train"`` is shuffled.
        dataset: One of ``'SMD'``, ``'MSL'``, ``'SMAP'``, ``'PSM'``, ``'SWaT'`` or
            ``'NeurIPSTS'``. The default ``'KDD'`` has no loader and is kept only
            for signature compatibility.
        num_workers: Number of DataLoader worker processes (default 8).

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.

    Raises:
        ValueError: if ``dataset`` is not a supported name. Previously the raw
            string was handed to ``DataLoader``, which then iterated its characters.
    """
    if dataset == 'SMD':
        loader_cls = SMDSegLoader
    elif dataset == 'MSL':
        loader_cls = MSLSegLoader
    elif dataset == 'SMAP':
        loader_cls = SMAPSegLoader
    elif dataset == 'PSM':
        loader_cls = PSMSegLoader
    elif dataset == 'SWaT':
        loader_cls = SWATSegLoader
    elif dataset == 'NeurIPSTS':
        loader_cls = SynSegLoader
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of "
            "'SMD', 'MSL', 'SMAP', 'PSM', 'SWaT', 'NeurIPSTS'")

    segment_dataset = loader_cls(data_path, win_size, step, mode)
    return DataLoader(dataset=segment_dataset,
                      batch_size=batch_size,
                      shuffle=(mode == 'train'),
                      num_workers=num_workers,
                      pin_memory=True)