import functools
import os
from typing import Tuple, Optional, Union, Callable, Dict, Any

import numpy as np
import pandas as pd
import torch

from utils.metadata import DATA_DIRECTORY
from .dataset import DatasetMixin
from .preprocessing import minmax_scaler


class SWaTDataset(torch.utils.data.Dataset, DatasetMixin):
    """
    Secure Water Treatment Dataset (Goh2016), exposed as a dataset of length 1 whose single item is the whole
    (optionally standardized) sequence of the selected split together with its per-step labels.
    """

    # Number of leading rows dropped from the training file when ``remove_startup`` is set.
    _STARTUP_SAMPLES = 12600
    # Row counts of "SWaT_Dataset_Normal_v1.csv" and "SWaT_Dataset_Attack_v0.csv" as reported by ``seq_len``.
    _NORMAL_FILE_SAMPLES = 495000
    _ATTACK_FILE_SAMPLES = 449919

    class _OutOfBoundsError(IndexError, ValueError):
        """
        Raised by ``__getitem__`` for an index outside ``[0, len(self))``.

        Subclasses IndexError so that Python's fallback sequence iteration (``for x in dataset``) terminates
        cleanly, and ValueError so that callers written against the previous behavior still catch it.
        """

    def __init__(self, path: str = os.path.join(DATA_DIRECTORY, 'swat', 'SWaT_A1A2_Dec_2015', 'Physical'),
                 training: bool = True, standardize: Union[bool, Callable] = True, remove_startup: bool = True):
        """
        Secure Water Treatment Dataset (Goh2016)

        :param path: Path where the files "SWaT_Dataset_Normal_v1.csv" and "SWaT_Dataset_Attack_v0.csv" are located.
            If standardization is enabled, "SWaT_Dataset_Normal_v1_stats.npz" must be located there as well.
        :param training: If True, this will load the training set consisting only of normal samples. Otherwise loads
            the test set, which includes attacks.
        :param standardize: If True, apply min-max scaling (based on the training set statistics stored in
            "SWaT_Dataset_Normal_v1_stats.npz"). This can also be a function that accepts a DataFrame as its
            positional argument and a keyword argument `stats`: a dictionary of training data statistics.
        :param remove_startup: If True, this will drop the first 12600 samples from the training set, as during
            this time the system was starting from an empty state (the first 30 minutes were already removed in v1
            of the Dataset). NOTE(review): earlier documentation described this as 4.5 hours, but 12600 samples
            is 3.5 hours if the data is sampled at 1 Hz -- confirm which is intended.
        """
        self.path = path
        self.training = training
        self.remove_startup = remove_startup

        # Populated lazily by __getitem__ on first access.
        self.inputs = None
        self.targets = None

        # Pick the scaling function; a user callable takes precedence over the boolean flag.
        if callable(standardize):
            scaler = standardize
        elif standardize:
            scaler = minmax_scaler
        else:
            scaler = None

        if scaler is None:
            self.standardize_fn = None
        else:
            # Training-set statistics are loaded once here and bound to the scaler.
            with np.load(os.path.join(self.path, 'SWaT_Dataset_Normal_v1_stats.npz')) as d:
                stats = dict(d)
            self.standardize_fn = functools.partial(scaler, stats=stats)

    def load_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Read the CSV file of the selected split and convert it to arrays.

        :return: Tuple ``(inputs, targets)``. ``inputs`` contains every column except the first and the last,
            cast to float32 (after standardization, if enabled). ``targets`` is an int64 array that is 1 where the
            "Normal/Attack" column says "Attack" and 0 otherwise. For the training split with ``remove_startup``,
            the first ``_STARTUP_SAMPLES`` rows are dropped from both.
        """
        split_name = 'Normal_v1' if self.training else 'Attack_v0'

        fname = f'SWaT_Dataset_{split_name}.csv'
        data = pd.read_csv(os.path.join(self.path, fname))

        # Convert string to int label. All whitespace is removed before comparing so that malformed variants
        # such as 'A ttack' (reported in some copies of the attack file) are counted as attacks as well.
        labels = data['Normal/Attack'].astype(str).str.replace(r'\s+', '', regex=True)
        targets = (labels == 'Attack').astype(np.int64).to_numpy()

        # Features are everything between the first column (meta data) and the last column (the label).
        feature_cols = data.columns[1:-1]
        if self.standardize_fn is not None:
            data[feature_cols] = self.standardize_fn(data[feature_cols])
        inputs = data[feature_cols].astype(np.float32).to_numpy()
        del data

        if self.training and self.remove_startup:
            inputs = inputs[self._STARTUP_SAMPLES:]
            targets = targets[self._STARTUP_SAMPLES:]

        return inputs, targets

    def __getitem__(self, item: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """
        Return the whole split as a single sample ``((inputs,), (targets,))``.

        The data is read from disk on the first call and cached on the instance.

        :param item: Sample index; only 0 is valid since the dataset has length 1.
        :raises IndexError: If ``item`` is out of bounds (the exception is also a ValueError).
        """
        if not (0 <= item < len(self)):
            raise self._OutOfBoundsError('Out of bounds')

        if self.inputs is None or self.targets is None:
            self.inputs, self.targets = self.load_data()

        return (torch.as_tensor(self.inputs),), (torch.as_tensor(self.targets),)

    def __len__(self) -> int:
        """The whole split is served as one item."""
        return 1

    @property
    def seq_len(self) -> Optional[int]:
        """Number of time steps in the sequence returned by ``__getitem__``."""
        if not self.training:
            return self._ATTACK_FILE_SAMPLES
        if self.remove_startup:
            return self._NORMAL_FILE_SAMPLES - self._STARTUP_SAMPLES
        return self._NORMAL_FILE_SAMPLES

    @property
    def num_features(self) -> int:
        """Number of input channels (51), i.e. the number of sensor/actuator names."""
        return len(self.get_feature_names())

    @staticmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        """Default transform pipeline: keep every 5th step (aggregation 'first'), then cache."""
        return {
            'subsample': {'class': 'SubsampleTransform', 'args': {'subsampling_factor': 5, 'aggregation': 'first'}},
            'cache': {'class': 'CacheTransform', 'args': {}}
        }

    @staticmethod
    def get_feature_names():
        """Names of the 51 input channels, in the same column order as ``inputs``."""
        return ['FIT101', 'LIT101', 'MV101', 'P101', 'P102', 'AIT201', 'AIT202', 'AIT203', 'FIT201', 'MV201',
                'P201', 'P202', 'P203', 'P204', 'P205', 'P206', 'DPIT301', 'FIT301', 'LIT301', 'MV301', 'MV302',
                'MV303', 'MV304', 'P301', 'P302', 'AIT401', 'AIT402', 'FIT401', 'LIT401', 'P401', 'P402', 'P403',
                'P404', 'UV401', 'AIT501', 'AIT502', 'AIT503', 'AIT504', 'FIT501', 'FIT502', 'FIT503', 'FIT504',
                'P501', 'P502', 'PIT501', 'PIT502', 'PIT503', 'FIT601', 'P601', 'P602', 'P603']