import os
import csv
import ast
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset, DataLoader
from utils import *


class TimeSeriesDataset(Dataset):
    """
    PyTorch Dataset for multivariate time-series anomaly-detection benchmarks.

    Supported datasets: PSM, SMD, SMAP, MSL, SWaT and WADI. Each sample is one
    time step, i.e. one 1-D row of min-max scaled features. For test splits
    ``self.labels`` holds one label per time step; for training splits it is
    None.

    Scaling: a ``MinMaxScaler`` is fitted on the training split and reused for
    the test split of the *same* ``dataset_name``, so the training Dataset must
    be constructed before the matching test Dataset.

    Args:
        dataset_name (str): 'PSM', 'SMD', 'SMAP', 'MSL', 'SWaT', or 'WADI'.
        root_dir (str): Base directory containing the raw datasets.
        train (bool): If True, load training data; if False, load test data.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
        RuntimeError: If a test Dataset is built before the training Dataset
            of the same name has fitted its scaler.
    """

    # Registry of fitted scalers keyed by dataset name. It is intentionally
    # shared at class level so a test split reuses the scaler fitted on its
    # own training split, rather than whichever training split was loaded last.
    _scalers: dict = {}

    def __init__(self, dataset_name, root_dir='./datasets', train=True):
        self.dataset_name = dataset_name
        self.train = train

        loaders = {
            'PSM': self._load_psm,
            'SMD': self._load_smd,
            'SMAP': self._load_nasa,
            'MSL': self._load_nasa,
            'SWaT': self._load_swat,
            'WADI': self._load_wadi,
        }
        if dataset_name not in loaders:
            raise ValueError(f'Unknown dataset: {dataset_name} (choose from  PSM, MSL, SMAP, SMD, SWaT, WADI)')
        self.data, self.labels = loaders[dataset_name](root_dir)

        cls = type(self)
        if self.train:
            scaler = MinMaxScaler()
            self.data = scaler.fit_transform(self.data)
            cls._scalers[dataset_name] = scaler
        else:
            scaler = cls._scalers.get(dataset_name)
            if scaler is None:
                raise RuntimeError("Training Dataset must be initialized before the test Dataset to fit the scaler.")
            self.data = scaler.transform(self.data)
        self.scaler = scaler
        # Kept for backward compatibility with code that reads the class attribute.
        cls.scaler = scaler

        n_rows, self.input_dim = self.data.shape
        if self.train:
            self.train_len = n_rows
        else:
            self.test_len = n_rows

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]

    def inverse_transform(self, data):
        """Map scaled values back to the original feature range of this dataset."""
        return self.scaler.inverse_transform(data)

    def _split_features_labels(self, arr):
        """Return (arr, None) for train splits, or (features, last column) for test splits."""
        if self.train:
            return arr, None
        return arr[:, :-1], arr[:, -1]

    def _load_psm(self, root_dir):
        """Load PSM from CSV, caching the processed array as a sibling .npy file."""
        folder_path = os.path.join(root_dir, 'PSM')
        csv_name = 'train.csv' if self.train else 'test.csv'
        csv_path = os.path.join(folder_path, csv_name)
        npy_path = os.path.splitext(csv_path)[0] + '.npy'

        if os.path.isfile(npy_path):
            # The cache is written by this method; allow_pickle covers object arrays.
            return self._split_features_labels(np.load(npy_path, allow_pickle=True))

        df = pd.read_csv(csv_path)
        # Drop the first column (NOTE(review): presumably a timestamp — confirm).
        df = df.drop(df.columns[0], axis=1)
        if not self.train:
            label = pd.read_csv(os.path.join(folder_path, 'test_label.csv'))
            df['label'] = label['label']

        # Labels are attached before dropping rows so they stay aligned.
        df = df.dropna(axis=1, how='all')
        df = df.dropna(axis=0, how='any')

        arr = df.values
        np.save(npy_path, arr)
        return self._split_features_labels(arr)

    def _load_smd(self, root_dir):
        """Load SMD by concatenating every per-machine .txt file of the split."""
        base = os.path.join(root_dir, 'ServerMachineDataset')
        folder_path = os.path.join(base, 'train' if self.train else 'test')
        # os.listdir order is arbitrary; sort so concatenation order is deterministic.
        files = sorted(f for f in os.listdir(folder_path) if f.endswith('.txt'))
        data = np.concatenate(
            [np.genfromtxt(os.path.join(folder_path, fname), dtype=np.float32, delimiter=',')
             for fname in files],
            axis=0,
        )
        if self.train:
            return data, None

        # Read labels for exactly the same file names, in the same order, so
        # the label rows line up with the data rows.
        label_folder = os.path.join(base, 'test_label')
        labels = np.concatenate(
            [np.genfromtxt(os.path.join(label_folder, fname), dtype=np.float32, delimiter=',')
             for fname in files],
            axis=0,
        )
        return data, labels

    def _load_nasa(self, root_dir):
        """Load SMAP or MSL channels listed in labeled_anomalies.csv (channel P-2 excluded)."""
        folder_path = os.path.join(root_dir, 'nasa', 'data')
        csv_path = os.path.join(folder_path, 'labeled_anomalies.csv')
        with open(csv_path, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            # r[0]: channel file name, r[1]: dataset name, r[2]: anomaly index
            # ranges, r[-1]: label length.
            rows = sorted(
                (r for r in reader if r[1] == self.dataset_name and r[0] != 'P-2'),
                key=lambda r: r[0],
            )

        split = 'train' if self.train else 'test'
        data_list, label_list = [], []
        for r in rows:
            data_list.append(np.load(os.path.join(folder_path, split, f'{r[0]}.npy')))
            if not self.train:
                lbl = np.zeros(int(r[-1]), dtype=np.bool_)
                for a in ast.literal_eval(r[2]):
                    lbl[a[0]:a[1] + 1] = True  # ranges include their end index
                label_list.append(lbl)

        data = np.concatenate(data_list, axis=0)
        labels = np.concatenate(label_list, axis=0) if label_list else None
        return data, labels

    def _load_swat(self, root_dir):
        """Load SWaT from the Excel exports, caching the processed array as .npy."""
        folder_path = os.path.join(root_dir, 'SWaT')
        xlsx_name = 'SWaT_Dataset_Normal_v0.xlsx' if self.train else 'SWaT_Dataset_Attack_v0.xlsx'
        xlsx_path = os.path.join(folder_path, xlsx_name)
        npy_path = os.path.splitext(xlsx_path)[0] + '.npy'

        if os.path.isfile(npy_path):
            # The cache is written by this method; allow_pickle covers object arrays.
            arr = np.load(npy_path, allow_pickle=True)
        else:
            df = pd.read_excel(xlsx_path, skiprows=1, engine="openpyxl")
            # Drop the first column (NOTE(review): presumably a timestamp — confirm).
            df = df.drop(df.columns[0], axis=1)
            if not self.train:
                df[df.columns[-1]] = df[df.columns[-1]].map({"Normal": 0, "Attack": 1, "A ttack": 1})
            arr = df.values
            np.save(npy_path, arr)

        # The last column is always excluded from the features; in the test
        # split it holds the mapped Normal/Attack label.
        data = arr[:, :-1].astype(np.float32)
        labels = None if self.train else arr[:, -1]
        return data, labels

    def _load_wadi(self, root_dir):
        """Load WADI from CSV, caching the processed array as a sibling .npy file."""
        folder_path = os.path.join(root_dir, 'WADI')
        csv_name = 'WADI_14days_new.csv' if self.train else 'WADI_attackdataLABLE.csv'
        csv_path = os.path.join(folder_path, csv_name)
        npy_path = os.path.splitext(csv_path)[0] + '.npy'

        if os.path.isfile(npy_path):
            # The cache is written by this method; allow_pickle covers object arrays.
            return self._split_features_labels(np.load(npy_path, allow_pickle=True))

        df = pd.read_csv(csv_path)
        # Drop leading metadata columns: 3 in the train file, 2 in the test file.
        # NOTE(review): presumably Row/Date/Time vs. Row/Date — confirm against the raw CSVs.
        df = df.drop(df.columns[0:3] if self.train else df.columns[0:2], axis=1)
        df = df.dropna(axis=1, how='all')
        df = df.dropna(axis=0, how='any')

        if not self.train:
            # Remap the label column: 1 -> 0 (normal), -1 -> 1 (anomaly).
            df[df.columns[-1]] = df[df.columns[-1]].replace({1: 0, -1: 1})

        arr = df.values
        np.save(npy_path, arr)
        return self._split_features_labels(arr)


class TimeSeriesLoader:
    """
    Build sliding-window DataLoaders for the train/test splits of one dataset.

    Windows of ``window_size`` consecutive time steps start every
    ``step_size`` steps. If the stride does not land exactly on the end of the
    series, one extra window aligned to the last time step is appended so that
    every step is covered. ``unroll_windows`` inverts this by averaging
    overlapping windows back into a per-time-step series.

    Args:
        dataset_name: Name accepted by ``TimeSeriesDataset``.
        root_dir: Base directory containing the raw datasets.
        window_size: Number of time steps per window.
        step_size: Stride between consecutive window starts.
        batch_size: Batch size for both DataLoaders.
        num_workers: Number of DataLoader worker processes.
        device: Device on which window tensors are created. Defaults to CUDA
            when available, else CPU.
    """

    def __init__(self, dataset_name: str, root_dir: str = './datasets', window_size: int = 100,
                 step_size: int = 10, batch_size: int = 64, num_workers: int = 0, device=None):
        # `is None` rather than truthiness so a falsy device such as 0 is kept.
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.window_size = window_size
        self.step_size = step_size
        self.dataset_name = dataset_name

        # The training split is built first: it fits the scaler the test split reuses.
        self.train_ds = TimeSeriesDataset(self.dataset_name, root_dir, train=True)
        self.test_ds = TimeSeriesDataset(self.dataset_name, root_dir, train=False)

        self.train_len = self.train_ds.train_len
        self.test_len = self.test_ds.test_len
        self.input_dim = self.train_ds.input_dim

        train_windows = self.rolling_windows(self.train_ds.data)
        test_windows = self.rolling_windows(self.test_ds.data)

        self.train_window_len = train_windows.shape[0]
        self.test_window_len = test_windows.shape[0]

        self.test_labels = self.test_ds.labels

        # NOTE(review): the window tensors live on self.device. With CUDA tensors
        # and num_workers > 0, worker processes may fail — confirm intended usage.
        loader_kwargs = dict(batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, drop_last=False)
        self.train_loader = DataLoader(train_windows, **loader_kwargs)
        self.test_loader = DataLoader(test_windows, **loader_kwargs)

    def rolling_windows(self, data):
        """
        Slice a (time, features) series into overlapping windows.

        Args:
            data: 2-D np.ndarray or pd.DataFrame of shape (n_steps, n_features).

        Returns:
            torch.Tensor of shape (n_windows, window_size, n_features) on
            ``self.device``. Window starts are 0, step_size, 2*step_size, ...
            If the last regular window does not end at the final step, a window
            covering the last ``window_size`` steps is appended.

        Raises:
            ValueError: If the series is shorter than ``window_size``.
        """
        values = data.values if isinstance(data, pd.DataFrame) else data
        series = torch.tensor(values, dtype=torch.float32).to(self.device)

        n, _n_features = series.shape
        w = self.window_size
        if n < w:
            raise ValueError(
                f"Series length {n} is shorter than window_size {w}; no window can be formed."
            )

        starts = list(range(0, n - w + 1, self.step_size))
        if starts[-1] + w < n:
            # The stride overshoots the tail: add one window aligned to the end.
            starts.append(n - w)
        return torch.stack([series[s:s + w] for s in starts])

    def unroll_windows(self, windows, data_type="train"):
        """
        Reassemble windows produced by ``rolling_windows`` into a per-step series.

        Window ``i`` is placed at ``i * step_size``. A window that would run past
        the end is aligned to the last ``window_size`` steps, matching the extra
        tail window from ``rolling_windows``. Overlapping values are averaged,
        and steps covered by no window stay 0.

        Args:
            windows: np.ndarray or torch.Tensor of shape
                (n_windows, window_size, n_features). Tensors may be on any
                device or require grad; they are detached and copied to CPU.
            data_type: "train" or "test", selecting the target series length.

        Returns:
            np.ndarray of shape (series_length, n_features) with dtype float64.

        Raises:
            ValueError: If ``data_type`` is not "train" or "test".
        """
        if isinstance(windows, torch.Tensor):
            windows = windows.detach().cpu().numpy()

        if data_type == "train":
            length = self.train_len
        elif data_type == "test":
            length = self.test_len
        else:
            raise ValueError("data_type must be 'train' or 'test'.")

        n_wins, window_size, feat = windows.shape
        recon = np.zeros((length, feat), dtype=float)
        counts = np.zeros(length, dtype=int)

        for i in range(n_wins):
            start = i * self.step_size
            end = start + window_size
            if end > length:
                start, end = length - window_size, length
            recon[start:end] += windows[i]
            counts[start:end] += 1

        covered = counts > 0
        recon[covered] /= counts[covered][:, None]
        return recon

if __name__ == '__main__':
    def dataset_info(train_data, test_data):
        """Return (input_dim, train_len, test_len, anomaly ratio in %) for a train/test pair."""
        anomaly_ratio = test_data.labels.sum() / len(test_data.labels) * 100
        return train_data.input_dim, train_data.train_len, test_data.test_len, anomaly_ratio

    # Each training split is loaded right before its test split, so the test
    # split is scaled with the scaler fitted on its own training data.
    for name in ('PSM', 'MSL', 'SMAP', 'SMD', 'SWaT', 'WADI'):
        train_ds = TimeSeriesDataset(name, train=True)
        test_ds = TimeSeriesDataset(name, train=False)
        dim, n_train, n_test, ratio = dataset_info(train_ds, test_ds)
        print(f'{name:<4}|Dimensions: {dim}, Train set: {n_train}, Test set: {n_test}, '
              f'Anomaly Ratio: {ratio:.2f}%')