from typing import Tuple, Optional, Dict, Any, Union, List, Callable
import os

import torch.utils.data
import numpy as np
import functools

from utils.metadata import DATA_DIRECTORY
from .dataset import DatasetMixin
from .preprocessing import minmax_scaler


# Per-machine data file names, indexed by ``server_id``.
#
# There are 8, 9 and 11 machines in groups 1, 2 and 3 respectively. The names
# are kept in *lexicographic* order (so 'machine-3-10.txt' comes before
# 'machine-3-2.txt'); TRAIN_LENS and TEST_LENS are indexed with the same
# ``server_id``, so this ordering must not change.
FILENAMES = sorted(
    f'machine-{group}-{index}.txt'
    for group, machine_count in ((1, 8), (2, 9), (3, 11))
    for index in range(1, machine_count + 1)
)

# Sequence length reported by ``SMDDataset.seq_len`` in training mode, indexed
# by ``server_id`` (same order as FILENAMES).
# NOTE(review): presumably the row count of the matching file under 'train/';
# the value is hard-coded and is not checked against the data when it is loaded.
TRAIN_LENS = [28479, 23694, 23702, 23706, 23705, 23688, 23697, 23698, 23693, 23699, 23688, 23689, 23688, 28743, 23696,
              23702, 28722, 28700, 23692, 28695, 23702, 23703, 23687, 23690, 28726, 28705, 28703, 28713]

# Sequence length reported by ``SMDDataset.seq_len`` in test mode, indexed by
# ``server_id`` (same order as FILENAMES).
# NOTE(review): presumably the row count of the matching file under 'test/';
# several entries differ by one from TRAIN_LENS, so the two tables are not
# interchangeable.
TEST_LENS = [28479, 23694, 23703, 23707, 23706, 23689, 23697, 23699, 23694, 23700, 23689, 23689, 23689, 28743, 23696,
             23703, 28722, 28700, 23693, 28696, 23703, 23703, 23687, 23691, 28726, 28705, 28704, 28713]


class SMDDataset(torch.utils.data.Dataset, DatasetMixin):
    """Dataset that exposes one machine's recording as a single sequence.

    The dataset always has length 1: its only item is the complete recording
    of the selected machine together with a per-timestep target vector.
    (NOTE(review): "SMD" presumably stands for the Server Machine Dataset.)

    Files are read lazily, on the first ``__getitem__`` call, from::

        <path>/train/<file>        features when ``training`` is true
        <path>/test/<file>         features when ``training`` is false
        <path>/test_label/<file>   targets  when ``training`` is false
        <path>/<stem>_stats.npz    statistics handed to the scaler

    where ``<file>`` is ``FILENAMES[server_id]`` and ``<stem>`` is that name
    without its extension.
    """

    # Feature count reported by ``num_features`` and ``get_feature_names``.
    _NUM_FEATURES = 38

    class _OutOfBounds(IndexError, ValueError):
        """Raised by ``__getitem__`` for an index outside ``[0, len(self))``.

        It is an ``IndexError`` so that Python's ``__getitem__``-based
        iteration (``for x in dataset``, ``list(dataset)``) terminates
        cleanly, and still a ``ValueError`` so that existing handlers for the
        previously raised exception type keep working.
        """

    def __init__(self, server_id: int, path: str = os.path.join(DATA_DIRECTORY, 'smd'),
                 training: bool = True, standardize: Union[bool, Callable] = True):
        """Select a machine and prepare the (optional) scaling function.

        Args:
            server_id: Index into ``FILENAMES`` selecting the machine.
            path: Root directory of the data set.
            training: If true, serve the training split with all-zero
                targets; otherwise serve the test split with its labels.
            standardize: ``False`` disables scaling, ``True`` uses
                ``minmax_scaler``, and a callable is used as the scaler
                itself. The scaler is invoked as ``scaler(data, stats=stats)``
                where ``stats`` is a dict holding the arrays stored in the
                machine's ``*_stats.npz`` file.

        Raises:
            ValueError: If ``server_id`` is not a valid index into
                ``FILENAMES``.
        """
        if not (0 <= server_id < len(FILENAMES)):
            raise ValueError(f'Server ID must be between 0 and {len(FILENAMES) - 1}! Given: {server_id}')

        self.server_id   = server_id
        self.path        = path
        self.training    = training
        self.standardize = standardize  # TODO

        # Filled in lazily by the first __getitem__ call.
        self.inputs  = None
        self.targets = None

        # A callable is taken as the scaler; any other truthy value selects
        # the default min-max scaler; a falsy value disables scaling.
        if callable(standardize):
            scaler = standardize
        elif standardize:
            scaler = minmax_scaler
        else:
            scaler = None

        if scaler is None:
            self.standardize_fn = None
        else:
            stem = os.path.splitext(FILENAMES[server_id])[0]
            # Copy the arrays out of the archive so it can be closed right away.
            with np.load(os.path.join(self.path, stem + '_stats.npz')) as archive:
                stats = dict(archive)
            self.standardize_fn = functools.partial(scaler, stats=stats)

    def load_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Read the selected split from disk.

        Returns:
            ``(data, target)``. ``data`` is the comma-separated feature file
            parsed as float32 and passed through the scaler when one is
            configured. ``target`` is a float32 vector: all zeros with one
            entry per row of ``data`` in training mode, otherwise the contents
            of the matching ``test_label`` file.
        """
        split = 'train' if self.training else 'test'
        filename = FILENAMES[self.server_id]

        data = np.genfromtxt(os.path.join(self.path, split, filename), dtype=np.float32, delimiter=',')

        if self.training:
            # No labels are read for the training split. float32 keeps the
            # dtype identical to the targets loaded in test mode.
            target = np.zeros(data.shape[0], dtype=np.float32)
        else:
            target = np.genfromtxt(os.path.join(self.path, 'test_label', filename), dtype=np.float32, delimiter=',')

        if self.standardize_fn is not None:
            data = self.standardize_fn(data)

        return data, target

    def __getitem__(self, item: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Return the whole recording as ``((inputs,), (targets,))``.

        The files are loaded on first access and cached on the instance.

        Raises:
            IndexError: If ``item`` is outside ``[0, len(self))``. The raised
                exception is also a ``ValueError``.
        """
        if not (0 <= item < len(self)):
            raise self._OutOfBounds('Out of bounds')

        if self.inputs is None or self.targets is None:
            self.inputs, self.targets = self.load_data()

        return (torch.as_tensor(self.inputs),), (torch.as_tensor(self.targets),)

    def __len__(self) -> int:
        """Return 1: the dataset holds exactly one sequence."""
        return 1

    @property
    def seq_len(self) -> Union[int, List[int]]:
        """Length of the selected split, looked up from the hard-coded tables.

        The value comes from ``TRAIN_LENS`` / ``TEST_LENS`` and does not
        require the data to be loaded.
        """
        lengths = TRAIN_LENS if self.training else TEST_LENS
        return lengths[self.server_id]

    @property
    def num_features(self) -> int:
        """Number of features reported for every recording (38)."""
        return self._NUM_FEATURES

    @staticmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        """Return the default pipeline configuration, which is empty."""
        return {}

    @staticmethod
    def get_feature_names() -> List[str]:
        """Return one empty-string placeholder name per feature."""
        return [''] * SMDDataset._NUM_FEATURES


